Skip to content

Commit

Permalink
Use NIOThreadPool.singleton instead of .createNew in multipart up…
Browse files Browse the repository at this point in the history
…load (#695)

* Use NIOThreadPool.singleton

* Revert some changes swiftformat made

* Get NIOThreadPool.singleton on background thread

* More logging

* Remove logging
  • Loading branch information
adam-fowler authored Oct 6, 2023
1 parent 37a8ece commit 92623a6
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 70 deletions.
3 changes: 2 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,8 @@ let package = Package(
.library(name: "SotoXRay", targets: ["SotoXRay"])
],
dependencies: [
.package(url: "https://github.com/soto-project/soto-core.git", branch: "main")
.package(url: "https://github.com/soto-project/soto-core.git", branch: "main"),
.package(url: "https://github.com/apple/swift-nio.git", from: "2.58.0"),
],
targets: [
.target(name: "SotoACM", dependencies: [.product(name: "SotoCore", package: "soto-core")], path: "./Sources/Soto/Services/ACM"),
Expand Down
36 changes: 29 additions & 7 deletions Sources/Soto/Extensions/S3/S3+multipart+async.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//

import Atomics
import Dispatch
import Logging
import NIOCore
import NIOPosix
Expand Down Expand Up @@ -123,12 +124,12 @@ extension S3 {
filename: String,
logger: Logger = AWSClient.loggingDisabled,
on eventLoop: EventLoop? = nil,
threadPoolProvider: ThreadPoolProvider = .createNew,
threadPoolProvider: ThreadPoolProvider = .singleton,
progress: @escaping (Double) throws -> Void = { _ in }
) async throws -> Int64 {
let eventLoop = eventLoop ?? self.client.eventLoopGroup.next()

let threadPool = threadPoolProvider.create()
let threadPool = await threadPoolProvider.create()
let fileIO = NonBlockingFileIO(threadPool: threadPool)
let fileHandle = try await fileIO.openFile(path: filename, mode: .write, flags: .allowFileCreation(), eventLoop: eventLoop).get()
let progressValue = ManagedAtomic(0)
Expand Down Expand Up @@ -178,7 +179,7 @@ extension S3 {
abortOnFail: Bool = true,
logger: Logger = AWSClient.loggingDisabled,
on eventLoop: EventLoop? = nil,
threadPoolProvider: ThreadPoolProvider = .createNew,
threadPoolProvider: ThreadPoolProvider = .singleton,
progress: @escaping @Sendable (Double) throws -> Void = { _ in }
) async throws -> CompleteMultipartUploadOutput {
let eventLoop = eventLoop ?? self.client.eventLoopGroup.next()
Expand Down Expand Up @@ -235,7 +236,7 @@ extension S3 {
abortOnFail: Bool = true,
logger: Logger = AWSClient.loggingDisabled,
on eventLoop: EventLoop? = nil,
threadPoolProvider: ThreadPoolProvider = .createNew,
threadPoolProvider: ThreadPoolProvider = .singleton,
progress: @escaping (Double) throws -> Void = { _ in }
) async throws -> CompleteMultipartUploadOutput {
let eventLoop = eventLoop ?? self.client.eventLoopGroup.next()
Expand Down Expand Up @@ -352,10 +353,10 @@ extension S3 {
filename: String,
logger: Logger,
on eventLoop: EventLoop,
threadPoolProvider: ThreadPoolProvider = .createNew,
threadPoolProvider: ThreadPoolProvider = .singleton,
uploadCallback: @escaping (NIOFileHandle, FileRegion, NonBlockingFileIO) async throws -> CompleteMultipartUploadOutput
) async throws -> CompleteMultipartUploadOutput {
let threadPool = threadPoolProvider.create()
let threadPool = await threadPoolProvider.create()
let fileIO = NonBlockingFileIO(threadPool: threadPool)
let (fileHandle, fileRegion) = try await fileIO.openFile(path: filename, eventLoop: eventLoop).get()

Expand Down Expand Up @@ -404,7 +405,7 @@ extension S3 {
abortOnFail: Bool = true,
logger: Logger = AWSClient.loggingDisabled,
on eventLoop: EventLoop? = nil,
threadPoolProvider: ThreadPoolProvider = .createNew,
threadPoolProvider: ThreadPoolProvider = .singleton,
progress: (@Sendable (Int) throws -> Void)? = nil
) async throws -> CompleteMultipartUploadOutput where ByteBufferSequence.Element == ByteBuffer {
// initialize multipart upload
Expand Down Expand Up @@ -682,6 +683,27 @@ extension S3 {

@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *)
extension S3.ThreadPoolProvider {
func create() async -> NIOThreadPool {
switch self {
case .createNew:
return await withUnsafeContinuation { (cont: UnsafeContinuation<NIOThreadPool, Never>) in
DispatchQueue.global(qos: .background).async {
let threadPool = NIOThreadPool(numberOfThreads: NonBlockingFileIO.defaultThreadPoolSize)
threadPool.start()
cont.resume(returning: threadPool)
}
}
case .singleton:
return await withUnsafeContinuation { (cont: UnsafeContinuation<NIOThreadPool, Never>) in
DispatchQueue.global(qos: .background).async {
cont.resume(returning: .singleton)
}
}
case .shared(let sharedPool):
return sharedPool
}
}

/// async version of destroy
func destroy(_ threadPool: NIOThreadPool) async throws {
if case .createNew = self {
Expand Down
127 changes: 73 additions & 54 deletions Sources/Soto/Extensions/S3/S3+multipart.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//
//===----------------------------------------------------------------------===//

import Dispatch
import NIO
import SotoCore

Expand All @@ -32,18 +33,33 @@ extension S3ErrorType {

extension S3 {
public enum ThreadPoolProvider {
/// Create new thread pool
@available(*, deprecated, message: "Use .singleton instead")
case createNew
/// Use thread pool supplied
case shared(NIOThreadPool)
/// singleton
case singleton

func create() -> NIOThreadPool {
let threadPool: NIOThreadPool
func create(eventLoop: EventLoop) -> EventLoopFuture<NIOThreadPool> {
switch self {
case .createNew:
threadPool = NIOThreadPool(numberOfThreads: NonBlockingFileIO.defaultThreadPoolSize)
threadPool.start()
return threadPool
let promise = eventLoop.makePromise(of: NIOThreadPool.self)
DispatchQueue.global(qos: .background).async {
let threadPool = NIOThreadPool(numberOfThreads: NonBlockingFileIO.defaultThreadPoolSize)
threadPool.start()
promise.completeWith(.success(threadPool))
}
return promise.futureResult
case .singleton:
let promise = eventLoop.makePromise(of: NIOThreadPool.self)
DispatchQueue.global(qos: .background).async {
promise.completeWith(.success(.singleton))
}

return promise.futureResult
case .shared(let sharedPool):
return sharedPool
return eventLoop.makeSucceededFuture(sharedPool)
}
}

Expand Down Expand Up @@ -160,39 +176,40 @@ extension S3 {
filename: String,
logger: Logger = AWSClient.loggingDisabled,
on eventLoop: EventLoop? = nil,
threadPoolProvider: ThreadPoolProvider = .createNew,
threadPoolProvider: ThreadPoolProvider = .singleton,
progress: @escaping (Double) throws -> Void = { _ in }
) -> EventLoopFuture<Int64> {
let eventLoop = eventLoop ?? self.client.eventLoopGroup.next()

let threadPool = threadPoolProvider.create()
let fileIO = NonBlockingFileIO(threadPool: threadPool)
return threadPoolProvider.create(eventLoop: eventLoop).flatMap { threadPool in
let fileIO = NonBlockingFileIO(threadPool: threadPool)

return fileIO.openFile(path: filename, mode: .write, flags: .allowFileCreation(), eventLoop: eventLoop).flatMap {
fileHandle -> EventLoopFuture<Int64> in
var progressValue: Int64 = 0
return fileIO.openFile(path: filename, mode: .write, flags: .allowFileCreation(), eventLoop: eventLoop).flatMap {
fileHandle -> EventLoopFuture<Int64> in
var progressValue: Int64 = 0

let download = self.multipartDownload(input, partSize: partSize, logger: logger, on: eventLoop) { byteBuffer, fileSize, eventLoop in
let bufferSize = byteBuffer.readableBytes
return fileIO.write(fileHandle: fileHandle, buffer: byteBuffer, eventLoop: eventLoop).flatMapThrowing { _ in
progressValue += Int64(bufferSize)
try progress(Double(progressValue) / Double(fileSize))
let download = self.multipartDownload(input, partSize: partSize, logger: logger, on: eventLoop) { byteBuffer, fileSize, eventLoop in
let bufferSize = byteBuffer.readableBytes
return fileIO.write(fileHandle: fileHandle, buffer: byteBuffer, eventLoop: eventLoop).flatMapThrowing { _ in
progressValue += Int64(bufferSize)
try progress(Double(progressValue) / Double(fileSize))
}
}
}

download.whenComplete { _ in
threadPoolProvider.destroy(threadPool)
download.whenComplete { _ in
threadPoolProvider.destroy(threadPool)
}
return
download
.flatMapErrorThrowing { error in
try fileHandle.close()
throw error
}
.flatMapThrowing { rt in
try fileHandle.close()
return rt
}
}
return
download
.flatMapErrorThrowing { error in
try fileHandle.close()
throw error
}
.flatMapThrowing { rt in
try fileHandle.close()
return rt
}
}
}

Expand Down Expand Up @@ -294,8 +311,8 @@ extension S3 {
) -> EventLoopFuture<CompleteMultipartUploadOutput> {
let eventLoop = eventLoop ?? self.client.eventLoopGroup.next()

var progressAmount: Int = 0
var prevProgressAmount: Int = 0
var progressAmount = 0
var prevProgressAmount = 0

return self.multipartUploadFromStream(input, abortOnFail: abortOnFail, logger: logger, on: eventLoop) { eventLoop in
let size = min(partSize, uploadSize - progressAmount)
Expand Down Expand Up @@ -337,11 +354,12 @@ extension S3 {
abortOnFail: Bool = true,
logger: Logger = AWSClient.loggingDisabled,
on eventLoop: EventLoop? = nil,
threadPoolProvider: ThreadPoolProvider = .createNew,
threadPoolProvider: ThreadPoolProvider = .singleton,
progress: @escaping (Double) throws -> Void = { _ in }
) -> EventLoopFuture<CompleteMultipartUploadOutput> {
let eventLoop = eventLoop ?? self.client.eventLoopGroup.next()

logger.debug("MultipartUpload of \(filename)")
return openFileForMultipartUpload(
filename: filename,
logger: logger,
Expand Down Expand Up @@ -457,8 +475,8 @@ extension S3 {
) -> EventLoopFuture<CompleteMultipartUploadOutput> {
let eventLoop = eventLoop ?? self.client.eventLoopGroup.next()

var progressAmount: Int = 0
var prevProgressAmount: Int = 0
var progressAmount = 0
var prevProgressAmount = 0

return self.resumeMultipartUpload(
input,
Expand Down Expand Up @@ -511,7 +529,7 @@ extension S3 {
abortOnFail: Bool = true,
logger: Logger = AWSClient.loggingDisabled,
on eventLoop: EventLoop? = nil,
threadPoolProvider: ThreadPoolProvider = .createNew,
threadPoolProvider: ThreadPoolProvider = .singleton,
progress: @escaping (Double) throws -> Void = { _ in }
) -> EventLoopFuture<CompleteMultipartUploadOutput> {
let eventLoop = eventLoop ?? self.client.eventLoopGroup.next()
Expand Down Expand Up @@ -552,7 +570,7 @@ extension S3 {
) -> EventLoopFuture<CompleteMultipartUploadOutput> {
let eventLoop = eventLoop ?? self.client.eventLoopGroup.next()

var uploadId: String = ""
var uploadId = ""

// initialize multipart upload
let request: CreateMultipartUploadRequest = .init(acl: input.acl, bucket: input.bucket, cacheControl: input.cacheControl, contentDisposition: input.contentDisposition, contentEncoding: input.contentEncoding, contentLanguage: input.contentLanguage, contentType: input.contentType, expectedBucketOwner: input.expectedBucketOwner, expires: input.expires, grantFullControl: input.grantFullControl, grantRead: input.grantRead, grantReadACP: input.grantReadACP, grantWriteACP: input.grantWriteACP, key: input.key, metadata: input.metadata, objectLockLegalHoldStatus: input.objectLockLegalHoldStatus, objectLockMode: input.objectLockMode, objectLockRetainUntilDate: input.objectLockRetainUntilDate, requestPayer: input.requestPayer, serverSideEncryption: input.serverSideEncryption, sseCustomerAlgorithm: input.sseCustomerAlgorithm, sseCustomerKey: input.sseCustomerKey, sseCustomerKeyMD5: input.sseCustomerKeyMD5, ssekmsEncryptionContext: input.ssekmsEncryptionContext, ssekmsKeyId: input.ssekmsKeyId, storageClass: input.storageClass, tagging: input.tagging, websiteRedirectLocation: input.websiteRedirectLocation)
Expand Down Expand Up @@ -638,30 +656,31 @@ extension S3 {
filename: String,
logger: Logger,
on eventLoop: EventLoop,
threadPoolProvider: ThreadPoolProvider = .createNew,
threadPoolProvider: ThreadPoolProvider = .singleton,
uploadCallback: @escaping (NIOFileHandle, FileRegion, NonBlockingFileIO) -> EventLoopFuture<CompleteMultipartUploadOutput>
) -> EventLoopFuture<CompleteMultipartUploadOutput> {
let threadPool = threadPoolProvider.create()
let fileIO = NonBlockingFileIO(threadPool: threadPool)
return threadPoolProvider.create(eventLoop: eventLoop).flatMap { threadPool in
let fileIO = NonBlockingFileIO(threadPool: threadPool)

return fileIO.openFile(path: filename, eventLoop: eventLoop).flatMap {
fileHandle, fileRegion -> EventLoopFuture<CompleteMultipartUploadOutput> in
return fileIO.openFile(path: filename, eventLoop: eventLoop).flatMap {
fileHandle, fileRegion -> EventLoopFuture<CompleteMultipartUploadOutput> in

logger.debug("Open file \(filename)")
logger.debug("Open file \(filename)")

let uploadFuture = uploadCallback(fileHandle, fileRegion, fileIO)
let uploadFuture = uploadCallback(fileHandle, fileRegion, fileIO)

uploadFuture.whenComplete { _ in
threadPoolProvider.destroy(threadPool)
}
return
uploadFuture.flatMapErrorThrowing { error in
try fileHandle.close()
throw error
}.flatMapThrowing { rt in
try fileHandle.close()
return rt
uploadFuture.whenComplete { _ in
threadPoolProvider.destroy(threadPool)
}
return
uploadFuture.flatMapErrorThrowing { error in
try fileHandle.close()
throw error
}.flatMapThrowing { rt in
try fileHandle.close()
return rt
}
}
}
}

Expand Down
12 changes: 6 additions & 6 deletions Tests/SotoTests/Services/S3/S3Tests+async.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,17 @@ class S3AsyncTests: XCTestCase {
print("Connecting to AWS")
}

Self.client = AWSClient(credentialProvider: TestEnvironment.credentialProvider, middlewares: TestEnvironment.middlewares, httpClientProvider: .createNew)
Self.s3 = S3(
client: Self.client,
self.client = AWSClient(credentialProvider: TestEnvironment.credentialProvider, middlewares: TestEnvironment.middlewares, httpClientProvider: .createNew)
self.s3 = S3(
client: self.client,
region: .useast1,
endpoint: TestEnvironment.getEndPoint(environment: "LOCALSTACK_ENDPOINT")
)
Self.randomBytes = Self.createRandomBuffer(size: 23 * 1024 * 1024)
self.randomBytes = self.createRandomBuffer(size: 23 * 1024 * 1024)
}

override class func tearDown() {
XCTAssertNoThrow(try Self.client.syncShutdown())
XCTAssertNoThrow(try self.client.syncShutdown())
}

static func createRandomBuffer(size: Int) -> Data {
Expand Down Expand Up @@ -408,7 +408,7 @@ class S3AsyncTests: XCTestCase {
try XCTSkipIf(TestEnvironment.isUsingLocalstack)

let name = TestEnvironment.generateResourceName()
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew)
let httpClient = HTTPClient(eventLoopGroupProvider: .singleton)
defer { XCTAssertNoThrow(try httpClient.syncShutdown()) }
let s3Url = URL(string: "https://\(name).s3.us-east-1.amazonaws.com/\(name)!=%25+/*()_.txt")!

Expand Down
2 changes: 1 addition & 1 deletion Tests/SotoTests/Services/S3/S3Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ class S3Tests: XCTestCase {
try XCTSkipIf(TestEnvironment.isUsingLocalstack)

let name = TestEnvironment.generateResourceName()
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew)
let httpClient = HTTPClient(eventLoopGroupProvider: .singleton)
defer { XCTAssertNoThrow(try httpClient.syncShutdown()) }
let s3Url = URL(string: "https://\(name).s3.us-east-1.amazonaws.com/\(name)!=%25+/*()_.txt")!

Expand Down
3 changes: 2 additions & 1 deletion scripts/templates/generate-package/Package.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ let package = Package(
{{/targets}}
],
dependencies: [
.package(url: "https://github.com/soto-project/soto-core.git", branch: "main")
.package(url: "https://github.com/soto-project/soto-core.git", branch: "main"),
.package(url: "https://github.com/apple/swift-nio.git", from: "2.58.0"),
],
targets: [
{{#targets}}
Expand Down

0 comments on commit 92623a6

Please sign in to comment.