diff --git a/Package.swift b/Package.swift index bd22724..ee330fa 100644 --- a/Package.swift +++ b/Package.swift @@ -32,8 +32,8 @@ let package = Package( ], path: "Sources/F5TTS", resources: [ - .copy("mel_filters.npy"), - .copy("test_en_1_ref_short.wav") + .copy("Resources/test_en_1_ref_short.wav"), + .copy("Resources/mel_filters.npy") ] ), .executableTarget( diff --git a/README.md b/README.md index e5efa12..337613e 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,12 @@ -# F5 TTS for Swift (WIP) +# F5 TTS for Swift Implementation of [F5-TTS](https://arxiv.org/abs/2410.06885) in Swift, using the [MLX Swift](https://github.com/ml-explore/mlx-swift) framework. You can listen to a [sample here](https://s3.amazonaws.com/lucasnewman.datasets/f5tts/sample.wav) that was generated in ~11 seconds on an M3 Max MacBook Pro. See the [Python repository](https://github.com/lucasnewman/f5-tts-mlx) for additional details on the model architecture. + This repository is based on the original Pytorch implementation available [here](https://github.com/SWivid/F5-TTS). @@ -19,21 +20,29 @@ A pretrained model is available [on Huggingface](https://hf.co/lucasnewman/f5-tt ## Usage ```swift -import Vocos import F5TTS let f5tts = try await F5TTS.fromPretrained(repoId: "lucasnewman/f5-tts-mlx") -let vocos = try await Vocos.fromPretrained(repoId: "lucasnewman/vocos-mel-24khz-mlx") // if decoding to audio output -let inputAudio = MLXArray(...) +let generatedAudio = try await f5tts.generate(text: "The quick brown fox jumped over the lazy dog.") +``` + +The result is an MLXArray with 24kHz audio samples. + +If you want to use your own reference audio sample, make sure it's a mono, 24kHz wav file of around 5-10 seconds: + +```swift +let generatedAudio = try await f5tts.generate( + text: "The quick brown fox jumped over the lazy dog.", + referenceAudioURL: ..., + referenceAudioText: "This is the caption for the reference audio." 
+) +``` + +You can convert an audio file to the correct format with ffmpeg like this: -let (outputAudio, _) = f5tts.sample( - cond: inputAudio, - text: ["This is the caption for the reference audio and generation text."], - duration: ..., - vocoder: vocos.decode) { progress in - print("Progress: \(Int(progress * 100))%") - } +```bash +ffmpeg -i /path/to/audio.wav -ac 1 -ar 24000 -sample_fmt s16 -t 10 /path/to/output_audio.wav ``` ## Appreciation diff --git a/Sources/F5TTS/CFM.swift b/Sources/F5TTS/F5TTS.swift similarity index 62% rename from Sources/F5TTS/CFM.swift rename to Sources/F5TTS/F5TTS.swift index 830f10e..049fa79 100644 --- a/Sources/F5TTS/CFM.swift +++ b/Sources/F5TTS/F5TTS.swift @@ -3,59 +3,15 @@ import Hub import MLX import MLXNN import MLXRandom +import Vocos -// utilities - -func lensToMask(t: MLXArray, length: Int? = nil) -> MLXArray { - let maxLength = length ?? t.max(keepDims: false).item(Int.self) - let seq = MLXArray(0.. MLXArray { - let ndim = t.ndim - - guard let seqLen = t.shape.last, length > seqLen else { - return t[0..., .ellipsis] - } - - let paddingValue = MLXArray(value ?? 0.0) - - let padded: MLXArray - switch ndim { - case 1: - padded = MLX.padded(t, widths: [.init((0, length - seqLen))], value: paddingValue) - case 2: - padded = MLX.padded(t, widths: [.init((0, 0)), .init((0, length - seqLen))], value: paddingValue) - case 3: - padded = MLX.padded(t, widths: [.init((0, 0)), .init((0, length - seqLen)), .init((0, 0))], value: paddingValue) - default: - fatalError("Unsupported padding dims: \(ndim)") - } - - return padded[0..., .ellipsis] -} - -func padSequence(_ t: [MLXArray], paddingValue: Float = 0) -> MLXArray { - let maxLen = t.map { $0.shape.last ?? 0 }.max() ?? 
0 - let t = MLX.stacked(t, axis: 0) - return padToLength(t, length: maxLen, value: paddingValue) -} - -func listStrToIdx(_ text: [String], vocabCharMap: [String: Int], paddingValue: Int = -1) -> MLXArray { - let listIdxTensors = text.map { str in str.map { char in vocabCharMap[String(char), default: 0] }} - let mlxArrays = listIdxTensors.map { MLXArray($0) } - let paddedText = padSequence(mlxArrays, paddingValue: Float(paddingValue)) - return paddedText.asType(.int32) -} - -// MARK: - +// MARK: - F5TTS public class F5TTS: Module { enum F5TTSError: Error { case unableToLoadModel + case unableToLoadReferenceAudio + case unableToDetermineDuration } public let melSpec: MelSpec @@ -100,20 +56,20 @@ public class F5TTS: Module { return MLX.stacked(ys, axis: 0) } - public func sample( + private func sample( cond: MLXArray, text: [String], duration: Any, lens: MLXArray? = nil, steps: Int = 32, - cfgStrength: Float = 2.0, - swayCoef: Float? = -1.0, + cfgStrength: Double = 2.0, + swayCoef: Double? = -1.0, seed: Int? = nil, maxDuration: Int = 4096, vocoder: ((MLXArray) -> MLXArray)? = nil, noRefAudio: Bool = false, editMask: MLXArray? = nil, - progressHandler: ((Float) -> Void)? = nil + progressHandler: ((Double) -> Void)? = nil ) -> (MLXArray, MLXArray) { MLX.eval(self.parameters()) @@ -183,7 +139,7 @@ public class F5TTS: Module { mask: mask ) - progressHandler?(t) + progressHandler?(Double(t)) return pred + (pred - nullPred) * cfgStrength } @@ -218,13 +174,82 @@ public class F5TTS: Module { return (out, trajectory) } + + public func generate( + text: String, + referenceAudioURL: URL? = nil, + referenceAudioText: String? = nil, + duration: TimeInterval? = nil, + cfg: Double = 2.0, + sway: Double = -1.0, + speed: Double = 1.0, + seed: Int? = nil, + progressHandler: ((Double) -> Void)? 
= nil
+    ) async throws -> MLXArray {
+        print("Loading Vocos model...")
+        let vocos = try await Vocos.fromPretrained(repoId: "lucasnewman/vocos-mel-24khz-mlx")
+
+        // load the reference audio + text
+
+        var audio: MLXArray
+        let referenceText: String
+
+        if let referenceAudioURL {
+            audio = try F5TTS.loadAudioArray(url: referenceAudioURL)
+            referenceText = referenceAudioText ?? ""
+        } else {
+            let refAudioAndCaption = try F5TTS.referenceAudio()
+            (audio, referenceText) = refAudioAndCaption
+        }
+
+        let refAudioDuration = Double(audio.shape[0]) / Double(F5TTS.sampleRate)
+        print("Using reference audio with duration: \(refAudioDuration)")
+
+        // use a heuristic to determine the duration if not provided
+
+        var generatedDuration = duration
+        if generatedDuration == nil {
+            generatedDuration = F5TTS.estimatedDuration(refAudio: audio, refText: referenceText, text: text, speed: speed)
+        }
+
+        guard let generatedDuration else {
+            throw F5TTSError.unableToDetermineDuration
+        }
+        print("Using generated duration: \(generatedDuration)")
+
+        // generate the audio
+
+        let normalizedAudio = F5TTS.normalizeAudio(audio: audio)
+
+        let processedText = referenceText + " " + text
+        let frameDuration = Int((refAudioDuration + generatedDuration) * F5TTS.framesPerSecond)
+        print("Generating \(generatedDuration) seconds (\(frameDuration) total frames) of audio...")
+
+        let (outputAudio, _) = self.sample(
+            cond: normalizedAudio.expandedDimensions(axis: 0),
+            text: [processedText],
+            duration: frameDuration,
+            steps: 32,
+            cfgStrength: cfg,
+            swayCoef: sway,
+            seed: seed,
+            vocoder: vocos.decode
+        ) { progress in
+            print("Generation progress: \(progress)"); progressHandler?(progress)
+        }
+
+        let generatedAudio = outputAudio[audio.shape[0]...]
+ return generatedAudio + } } -// MARK: - +// MARK: - Pretrained Models public extension F5TTS { - static func fromPretrained(repoId: String) async throws -> F5TTS { - let modelDirectoryURL = try await Hub.snapshot(from: repoId, matching: ["*.safetensors", "*.txt"]) + static func fromPretrained(repoId: String, downloadProgress: ((Progress) -> Void)? = nil) async throws -> F5TTS { + let modelDirectoryURL = try await Hub.snapshot(from: repoId, matching: ["*.safetensors", "*.txt"]) { progress in + downloadProgress?(progress) + } return try self.fromPretrained(modelDirectoryURL: modelDirectoryURL) } @@ -273,3 +298,97 @@ public extension F5TTS { return f5tts } } + +// MARK: - Utilities + +public extension F5TTS { + static var sampleRate: Int = 24000 + static var hopLength: Int = 256 + static var framesPerSecond: Double = .init(sampleRate) / Double(hopLength) + + static func loadAudioArray(url: URL) throws -> MLXArray { + return try AudioUtilities.loadAudioFile(url: url) + } + + static func referenceAudio() throws -> (MLXArray, String) { + guard let url = Bundle.module.url(forResource: "test_en_1_ref_short", withExtension: "wav") else { + throw F5TTSError.unableToLoadReferenceAudio + } + + return try ( + self.loadAudioArray(url: url), + "Some call me nature, others call me mother nature." + ) + } + + static func normalizeAudio(audio: MLXArray, targetRMS: Double = 0.1) -> MLXArray { + let rms = Double(audio.square().mean().sqrt().item(Float.self)) + if rms < targetRMS { + return audio * targetRMS / rms + } + return audio + } + + static func estimatedDuration(refAudio: MLXArray, refText: String, text: String, speed: Double = 1.0) -> TimeInterval { + let refDurationInFrames = refAudio.shape[0] / self.hopLength + let pausePunctuation = "。,、;:?!" 
+ let refTextLength = refText.utf8.count + 3 * pausePunctuation.utf8.count + let genTextLength = text.utf8.count + 3 * pausePunctuation.utf8.count + + let refAudioToTextRatio = Double(refDurationInFrames) / Double(refTextLength) + let textLength = Double(genTextLength) / speed + let estimatedDurationInFrames = Int(refAudioToTextRatio * textLength) + + let estimatedDuration = TimeInterval(estimatedDurationInFrames) / Self.framesPerSecond + print("Using duration of \(estimatedDuration) seconds (\(estimatedDurationInFrames) frames) for generated speech.") + + return estimatedDuration + } +} + +// MLX utilities + +func lensToMask(t: MLXArray, length: Int? = nil) -> MLXArray { + let maxLength = length ?? t.max(keepDims: false).item(Int.self) + let seq = MLXArray(0.. MLXArray { + let ndim = t.ndim + + guard let seqLen = t.shape.last, length > seqLen else { + return t[0..., .ellipsis] + } + + let paddingValue = MLXArray(value ?? 0.0) + + let padded: MLXArray + switch ndim { + case 1: + padded = MLX.padded(t, widths: [.init((0, length - seqLen))], value: paddingValue) + case 2: + padded = MLX.padded(t, widths: [.init((0, 0)), .init((0, length - seqLen))], value: paddingValue) + case 3: + padded = MLX.padded(t, widths: [.init((0, 0)), .init((0, length - seqLen)), .init((0, 0))], value: paddingValue) + default: + fatalError("Unsupported padding dims: \(ndim)") + } + + return padded[0..., .ellipsis] +} + +func padSequence(_ t: [MLXArray], paddingValue: Float = 0) -> MLXArray { + let maxLen = t.map { $0.shape.last ?? 0 }.max() ?? 
0 + let t = MLX.stacked(t, axis: 0) + return padToLength(t, length: maxLen, value: paddingValue) +} + +func listStrToIdx(_ text: [String], vocabCharMap: [String: Int], paddingValue: Int = -1) -> MLXArray { + let listIdxTensors = text.map { str in str.map { char in vocabCharMap[String(char), default: 0] }} + let mlxArrays = listIdxTensors.map { MLXArray($0) } + let paddedText = padSequence(mlxArrays, paddingValue: Float(paddingValue)) + return paddedText.asType(.int32) +} diff --git a/Sources/F5TTS/mel_filters.npy b/Sources/F5TTS/Resources/mel_filters.npy similarity index 100% rename from Sources/F5TTS/mel_filters.npy rename to Sources/F5TTS/Resources/mel_filters.npy diff --git a/Sources/F5TTS/test_en_1_ref_short.wav b/Sources/F5TTS/Resources/test_en_1_ref_short.wav similarity index 100% rename from Sources/F5TTS/test_en_1_ref_short.wav rename to Sources/F5TTS/Resources/test_en_1_ref_short.wav diff --git a/Sources/f5-tts-generate/GenerateCommand.swift b/Sources/f5-tts-generate/GenerateCommand.swift index 59c902c..59f583c 100644 --- a/Sources/f5-tts-generate/GenerateCommand.swift +++ b/Sources/f5-tts-generate/GenerateCommand.swift @@ -1,7 +1,7 @@ import ArgumentParser -import MLX import F5TTS import Foundation +import MLX import Vocos @main @@ -25,87 +25,40 @@ struct GenerateAudio: AsyncParsableCommand { var outputPath: String = "output.wav" @Option(name: .long, help: "Strength of classifier free guidance") - var cfg: Float = 2.0 + var cfg: Double = 2.0 @Option(name: .long, help: "Coefficient for sway sampling") - var sway: Float = -1.0 + var sway: Double = -1.0 @Option(name: .long, help: "Speed factor for the duration heuristic") - var speed: Float = 1.0 + var speed: Double = 1.0 @Option(name: .long, help: "Seed for noise generation") var seed: Int? 
func run() async throws { - let sampleRate = 24_000 - let hopLength = 256 - let framesPerSec = Double(sampleRate) / Double(hopLength) - let targetRMS: Float = 0.1 - - let f5tts = try await F5TTS.fromPretrained(repoId: model) - let vocos = try await Vocos.fromPretrained(repoId: "lucasnewman/vocos-mel-24khz-mlx") - - var audio: MLXArray - let referenceText: String - - if let refPath = refAudioPath { - audio = try AudioUtilities.loadAudioFile(url: URL(filePath: refPath)) - referenceText = refAudioText ?? "Some call me nature, others call me mother nature." - } else if let refURL = Bundle.main.url(forResource: "test_en_1_ref_short", withExtension: "wav") { - audio = try AudioUtilities.loadAudioFile(url: refURL) - referenceText = "Some call me nature, others call me mother nature." - } else { - fatalError("No reference audio file specified.") - } - - let rms = audio.square().mean().sqrt().item(Float.self) - if rms < targetRMS { - audio = audio * targetRMS / rms + print("Loading F5-TTS model...") + let f5tts = try await F5TTS.fromPretrained(repoId: model) { progress in + print(" -- \(progress.completedUnitCount) of \(progress.totalUnitCount)") } - // use a heuristic to determine the duration if not provided - let refAudioDuration = Double(audio.shape[0]) / framesPerSec - var generatedDuration = duration - - if generatedDuration == nil { - let refAudioLength = audio.shape[0] / hopLength - let pausePunctuation = "。,、;:?!" 
- let refTextLength = referenceText.utf8.count + 3 * pausePunctuation.utf8.count - let genTextLength = text.utf8.count + 3 * pausePunctuation.utf8.count - - let durationInFrames = refAudioLength + Int((Double(refAudioLength) / Double(refTextLength)) * (Double(genTextLength) / Double(speed))) - let estimatedDuration = Double(durationInFrames - refAudioLength) / framesPerSec - - print("Using duration of \(estimatedDuration) seconds for generated speech.") - generatedDuration = estimatedDuration - } - - guard let generatedDuration else { - fatalError("Unable to determine duration.") - } - - let processedText = referenceText + " " + text - let frameDuration = Int((refAudioDuration + generatedDuration) * framesPerSec) - print("Generating \(frameDuration) frames of audio...") - let startTime = Date() - let (outputAudio, _) = f5tts.sample( - cond: audio.expandedDimensions(axis: 0), - text: [processedText], - duration: frameDuration, - steps: 32, - cfgStrength: cfg, - swayCoef: sway, - seed: seed, - vocoder: vocos.decode + let generatedAudio = try await f5tts.generate( + text: text, + referenceAudioURL: refAudioPath != nil ? URL(filePath: refAudioPath!) : nil, + referenceAudioText: refAudioText, + duration: duration, + cfg: cfg, + sway: sway, + speed: speed, + seed: seed ) - let generatedAudio = outputAudio[audio.shape[0]...] - let elapsedTime = Date().timeIntervalSince(startTime) - print("Generated \(Double(generatedAudio.count) / Double(sampleRate)) seconds of audio in \(elapsedTime) seconds.") + print("Generated \(Double(generatedAudio.shape[0]) / Double(F5TTS.sampleRate)) seconds of audio in \(elapsedTime) seconds.") try AudioUtilities.saveAudioFile(url: URL(filePath: outputPath), samples: generatedAudio) + print("Saved audio to: \(outputPath)") } }