Skip to content

Commit

Permalink
run transcribe tasks in parallel
Browse files (browse the repository at this point in the history)
  • Loading branch information
mcdallas committed Nov 10, 2023
1 parent dc87941 commit e28a9b9
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ edition = "2021"
[dependencies]
anyhow = "1.0.69"
clap = { version = "4.1.8", features = ["derive", "cargo"] }
futures = "0.3.29"
log = "0.4.20"
reqwest = { version = "0.11.14", features = ["json", "multipart", "stream"] }
serde = { version = "1.0.153", features = ["derive"] }
serde_json = "1.0.94"
Expand Down
37 changes: 32 additions & 5 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use clap::{Parser, crate_authors, crate_version, crate_description};
use std::path::PathBuf;
use tokio::{self, fs::File, io::AsyncWriteExt};
use tokio::{self, task};
use tempdir::TempDir;
use futures::future::join_all;
use log::debug;

pub mod gpt;
pub mod whisper;
Expand All @@ -26,16 +28,41 @@ async fn main() {
let tmp_dir = TempDir::new("audio").expect("Could not create temporary directory");
let segments = util::split_file(args.file.clone(), &tmp_dir).await.expect("Could not split file");

let mut transcribed = Vec::new();

let mut transcribe_tasks = vec![];

let client = whisper::WhisperClient::new(api_key.clone());
for segment in segments {
let segment = client.transcribe(segment).await.expect("Could not transcribe audio");
transcribed.push(segment);
let client = whisper::WhisperClient::new(api_key.clone());
let task = task::spawn(async move {
debug!("Transcribing segment {:?}", segment.index);
let res = client.transcribe(segment).await;

match res {
Ok(segment) => {
debug!("End segment {:?}", segment.index);
return segment
},
Err(e) => panic!("Error transcribing segment: {}", e),
}
});
transcribe_tasks.push(task);
}

let results: Vec<_> = join_all(transcribe_tasks).await.into_iter().collect();


tmp_dir.close().expect("Could not delete temporary directory");

let mut transcribed = Vec::new();

for result in results {
match result {
Ok(segment) => transcribed.push(segment),

Err(e) => eprintln!("Task failed: {}", e),
}
}

let transcript = transcribed.iter().map(|segment| {
segment.transcript.clone()
}).collect::<Vec<String>>().join("\n");
Expand Down

0 comments on commit e28a9b9

Please sign in to comment.