Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kokoros onnx support #1634

Open
Areopagitics opened this issue Feb 2, 2025 · 0 comments
Open

Kokoros onnx support #1634

Areopagitics opened this issue Feb 2, 2025 · 0 comments

Comments

@Areopagitics
Copy link

I tried applied the tract model with the help of some ai, and this is what I got (at least it builds partly - see errors below).

use hound;
use ndarray::Array;
use tract_onnx::prelude::*;
use std::collections::HashMap;
use std::fs::File;
use std::io::{self, Read};
use std::path::Path;
use std::process::Command;
const SUPPORTED_LANGUAGES: [&str; 7] = [
    "en-us", // English
    "en-gb", // English (British)
    "es",    // Spanish
    "fr-fr", // French
    "ja",    // Japanese
    "ko",    // Korean
    "cmn",   // Mandarin Chinese
];
const MAX_PHONEME_LENGTH: usize = 510;
const SAMPLE_RATE: usize = 24000;
#[derive(Debug)]
pub struct EspeakConfig {
    pub lib_path: Option<String>,
    pub data_path: Option<String>,
}
#[derive(Debug)]
pub struct KoKoroConfig {
    pub model_path: String,
    pub voices_path: String,
    pub espeak_config: Option<EspeakConfig>,
}
impl KoKoroConfig {
    pub fn new(model_path: &str, voices_path: &str, espeak_config: Option<EspeakConfig>) -> Self {
        KoKoroConfig {
            model_path: model_path.to_string(),
            voices_path: voices_path.to_string(),
            espeak_config,
        }
    }
    pub fn validate(&self) -> Result<(), String> {
        if !Path::new(&self.voices_path).exists() {
            let error_msg = format!(
                "Voices file not found at {}. You can download the voices file using the following command:\nwget https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files/voices.bin",
                self.voices_path
            );
            return Err(error_msg);
        }
        if !Path::new(&self.model_path).exists() {
            let error_msg = format!(
                "Model file not found at {}. You can download the model file using the following command:\nwget https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files/kokoro-v0_19.onnx",
                self.model_path
            );
            return Err(error_msg);
        }
        Ok(())
    }
}
pub fn get_vocab() -> HashMap<char, usize> {
    let pad = '$';
    let punctuation = ";:,.!?¡¿—…\"«»“ ” ";
    let letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    let letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ";
    let mut symbols: Vec<char> = vec![pad];
    symbols.extend(punctuation.chars());
    symbols.extend(letters.chars());
    symbols.extend(letters_ipa.chars());
    let mut vocab_map = HashMap::new();
    for (index, symbol) in symbols.iter().enumerate() {
        vocab_map.insert(*symbol, index);
    }
    vocab_map
}
pub struct Tokenizer {
    espeak_config: EspeakConfig,
}
impl Tokenizer {
    pub fn new(espeak_config: EspeakConfig) -> Self {
        Tokenizer { espeak_config }
    }
    fn phonemize(&self, text: &str, lang: &str) -> Result<String, Box<dyn std::error::Error>> {
        // Run eSpeak and capture the output
        let output = Command::new("espeak-ng")
            .args(&["-q", "-x", text, "--lang", lang]) // Arguments to eSpeak
            .output()?;
        // Check for errors
        if !output.status.success() {
            return Err(Box::new(io::Error::new(io::ErrorKind::Other, "eSpeak failed to execute")));
        }
        // Convert the output to a String
        let phonemes = String::from_utf8_lossy(&output.stdout).to_string();
        Ok(phonemes.trim().to_string())
    }
    pub fn tokenize(&self, text: &str, lang: &str) -> Result<Vec<String>, Box<dyn std::error::Error>> {
        let phonemes = self.phonemize(text, lang)?;
        if phonemes.len() > MAX_PHONEME_LENGTH {
            return Err(Box::new(io::Error::new(io::ErrorKind::InvalidInput, format!(
                "Text is too long, must be less than {} phonemes",
                MAX_PHONEME_LENGTH
            ))));
        }
        // Return a list of phonemes
        Ok(phonemes.split_whitespace().map(|s| s.to_string()).collect())
    }
}
pub struct Kokoro {
    model: SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
    voices: HashMap<String, Vec<f32>>, // Store voices as vectors (or more complex types as needed)
    vocab: HashMap<char, usize>, // Runtime vocab
    tokenizer: Tokenizer,
}
impl Kokoro {
    pub fn new(model_path: &str, voices_path: &str, espeak_config: EspeakConfig) -> Result<Self, Box<dyn std::error::Error>> {
        let model = tract_onnx::onnx()
            .model_for_path(model_path)?
            .into_optimized()?
            .into_runnable()?;
        let voices = read_voices(voices_path)?;
        let vocab = get_vocab(); // Get the vocabulary at runtime
        let tokenizer = Tokenizer::new(espeak_config);
        Ok(Kokoro { model, voices, vocab, tokenizer })
    }
    fn get_voice_style(&self, name: &str) -> &[f32] {
        self.voices.get(name).expect("Voice not found")
    }
    fn create_audio(&self, tokens: &[String], voice: &[f32], speed: f32) -> Result<Tensor, Box<dyn std::error::Error>> {
        assert!(tokens.len() <= MAX_PHONEME_LENGTH, "Too many phonemes!");
        // Prepare token input for ONNX model
        let mut token_input = Array::<u8, _>::zeros((1, tokens.len() + 2)); // Use u8 for input
        for (i, token) in tokens.iter().enumerate() {
            token_input[[0, i + 1]] = *self.vocab.get(&token.chars().next().unwrap()).unwrap_or(&0) as u8; // Ensure input is u8
        }
        // Convert the input array to a Tensor
        let input_tensor: Tensor = token_input.into_tensor();
        // Create the input tensor vector needed for ONNX model
        let input_tensors: TVec<TValue> = tvec!(input_tensor.into());
        let result = self.model.run(input_tensors)?;
        // Convert the result back to a tensor
        let tensor: Tensor = result[0].to_owned().into_tensor();
        Ok(tensor)
    }
    pub fn create(&self, text: &str, voice: &str, speed: f32, lang: &str) -> Result<(Vec<f32>, usize), Box<dyn std::error::Error>> {
        // Language validation
        if !SUPPORTED_LANGUAGES.contains(&lang) {
            return Err(Box::new(io::Error::new(io::ErrorKind::InvalidInput, format!(
                "Language must be one of: {:?}. Got: {}", SUPPORTED_LANGUAGES, lang
            ))));
        }
        let tokens = self.tokenizer.tokenize(text, lang)?; // Use the Tokenizer for phonemization
        let voice_style = self.get_voice_style(voice);
        let audio_tensor = self.create_audio(&tokens, voice_style, speed)?;
        let audio_samples: Vec<f32> = audio_tensor.to_array_view::<f32>()?.iter().copied().collect();
        Ok((audio_samples, SAMPLE_RATE))
    }
}
// Function to read voices from a binary file
fn read_voices(file_path: &str) -> io::Result<HashMap<String, Vec<f32>>> {
    let mut file = File::open(file_path)?;
    let mut buffer = Vec::new();
    file.read_to_end(&mut buffer)?;
    // Example parsing logic (you will need a proper way to interpret your binary file)
    let mut voices = HashMap::new();
    // Implement voice parsing logic to fill the voices HashMap here (example placeholder)
    voices.insert("af_heart".to_string(), vec![0.0; 256]); // Placeholder data
    Ok(voices)
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let model_path = "kokoro-v1.0.fp16.onnx"; // Update path as necessary
    let voices_path = "voices.bin"; // Update path as necessary
    let espeak_config = EspeakConfig {
        lib_path: None, // Optionally set paths if needed
        data_path: None,
    };
    
    // Create the KoKoro synthesizer
    let kokoro = Kokoro::new(model_path, voices_path, espeak_config)?;
    
    // Input Text Configuration
    let text = "This is an English phrase for synthesis."; // Example English input
    let lang = "en-us";
    // Generate audio sample
    let (audio_samples, sample_rate) = kokoro.create(text, "af_heart", 1.0, lang)?;
    // Write the generated audio to a WAV file
    let mut writer = hound::WavWriter::create("audio.wav", hound::WavSpec {
        channels: 1, // Mono
        sample_rate: sample_rate as u32,
        bits_per_sample: 16, // Change as necessary
        sample_format: hound::SampleFormat::Int,
    })?;
    
    for sample in audio_samples {
        let sample_i16 = (sample * i16::MAX as f32).round() as i16; // Convert f32 to i16
        writer.write_sample(sample_i16)?;
    }
    writer.finalize()?;
    
    Ok(())
}

I get the following error:

Error: Failed analyse for node #1034 "/encoder/text_encoder/cnn.0/cnn.0.0/Conv_quant" ConvHir

Caused by:
    0: Infering facts
    1: Applying rule inputs[0].datum_type == inputs[3].datum_type
    2: Impossible to unify U8 with I8.

Any help in the right direction would be appreciated. There is a also a regular fp32 version of the onnx model.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant