Skip to content

Commit

Permalink
Eval after each ODE step for accurate progress.
Browse files: browse the repository at this point in the history
  • Loading branch information
lucasnewman committed Oct 21, 2024
1 parent 4a7cba3 commit 1dbfd84
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 12 deletions.
2 changes: 0 additions & 2 deletions Sources/F5TTS/DiT.swift
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ class InputEmbedding: Module {

let combined = MLX.concatenated([x, cond, textEmbed], axis: -1)
var output = proj(combined)
output.eval()
output = conv_pos_embed(output) + output
output.eval()
return output
}
}
Expand Down
15 changes: 9 additions & 6 deletions Sources/F5TTS/F5TTS.swift
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,10 @@ public class F5TTS: Module {

progressHandler?(Double(t))

return pred + (pred - nullPred) * cfgStrength
let output = pred + (pred - nullPred) * cfgStrength
output.eval()

return output
}

// noise input
Expand All @@ -165,11 +168,10 @@ public class F5TTS: Module {
let trajectory = self.odeint(fun: fn, y0: y0Padded, t: t)
let sampled = trajectory[-1]
var out = MLX.where(condMask, cond, sampled)

if let vocoder = vocoder {
out = vocoder(out)
}

out.eval()

return (out, trajectory)
Expand Down Expand Up @@ -239,6 +241,8 @@ public class F5TTS: Module {
}

let generatedAudio = outputAudio[audio.shape[0]...]

print("Got generated audio of shape: \(generatedAudio.shape)")
return generatedAudio
}
}
Expand Down Expand Up @@ -331,9 +335,8 @@ public extension F5TTS {

static func estimatedDuration(refAudio: MLXArray, refText: String, text: String, speed: Double = 1.0) -> TimeInterval {
let refDurationInFrames = refAudio.shape[0] / self.hopLength
let pausePunctuation = "。,、;:?!"
let refTextLength = refText.utf8.count + 3 * pausePunctuation.utf8.count
let genTextLength = text.utf8.count + 3 * pausePunctuation.utf8.count
let refTextLength = refText.utf8.count
let genTextLength = text.utf8.count

let refAudioToTextRatio = Double(refDurationInFrames) / Double(refTextLength)
let textLength = Double(genTextLength) / speed
Expand Down
5 changes: 1 addition & 4 deletions Sources/F5TTS/Modules.swift
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,7 @@ func precomputeFreqsCis(dim: Int, end: Int, theta: Float = 10000.0, thetaRescale
let freqsCos = outerFreqs.cos()
let freqsSin = outerFreqs.sin()

let output = MLX.concatenated([freqsCos, freqsSin], axis: -1)
output.eval()

return output
return MLX.concatenated([freqsCos, freqsSin], axis: -1)
}

func getPosEmbedIndices(start: MLXArray, length: Int, maxPos: Int, scale: Float = 1.0) -> MLXArray {
Expand Down

0 comments on commit 1dbfd84

Please sign in to comment.