forked from thewh1teagle/sherpa-rs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathonline_recognizer.rs
102 lines (90 loc) · 3.12 KB
/
online_recognizer.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use cpal::{FromSample, Sample, SampleRate};
use sherpa_rs::online::paraformer::Paraformer;
use sherpa_rs::online::stream::recognizer::{Recognizer, Search, Stream};
use sherpa_rs::online::stream::OnlineStream;
use std::fs::File;
use std::io::BufWriter;
use std::path::Path;
use std::sync::mpsc::{Receiver, Sender};
use std::sync::{Arc, Mutex};
// websocket -> VAD -> hotword? -> ASR -> ?
// websocket -> hotword? -> ASR -> ?
fn main() -> Result<(), anyhow::Error> {
let host = cpal::default_host();
// Set up the input device and stream with the default input config.
let device = host
.default_input_device()
.expect("failed to find input device");
println!("Input device: {}", device.name()?);
let config = cpal::StreamConfig {
channels: 1,
sample_rate: SampleRate(16000),
buffer_size: cpal::BufferSize::Default,
};
let encoder = Path::new(
"/home/lemonxh/下载/sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx",
);
let decoder = Path::new(
"/home/lemonxh/下载/sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx",
);
let tokens =
Path::new("/home/lemonxh/下载/sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt");
let tr = Paraformer::new(encoder, decoder);
let online_rec = Recognizer::from_paraformer(
tr,
Some("cpu"),
tokens,
Search::Greedy,
false,
None,
None,
None,
);
let rec: Stream = Stream::from_recognizer(online_rec);
println!("Begin recording...");
let (recorder, receiver) = std::sync::mpsc::channel();
let err_fn = move |err| {
eprintln!("an error occurred on stream: {}", err);
};
let stream = device.build_input_stream_raw(
&config,
cpal::SampleFormat::F32,
move |x, _| {
let x = x.as_slice().unwrap().to_vec();
recorder.clone().send(x).unwrap();
},
err_fn,
None,
)?;
stream.play()?;
println!("Creating recognizer...");
recognizer(rec, receiver);
}
/// Runs the streaming recognition loop forever on audio chunks received from `receiver`.
///
/// Each chunk is fed to `online_rec` at a sample rate of 16000, and the stream is decoded
/// while it reports it is ready. Whenever the current hypothesis is non-empty and differs
/// from the last one shown, it is printed lowercased and tab-indented. When the stream
/// reports an endpoint, a non-empty hypothesis is printed as the segment's final result
/// and the segment counter advances. The stream is then reset, and the remembered text is
/// cleared, so the next segment's first partial is always shown.
///
/// # Panics
/// Panics if every `Sender` for `receiver` has been dropped, since no further audio can
/// arrive.
fn recognizer(mut online_rec: Stream, receiver: Receiver<Vec<f32>>) -> ! {
    let mut last_text = String::new();
    let mut segment_index = 0;
    println!("current segment: {}", segment_index);
    loop {
        let samples = receiver
            .recv()
            .expect("audio sender disconnected; no more samples can arrive");
        // Must match the capture sample rate configured in `main`.
        online_rec.accept_waveform(16000, samples);
        while online_rec.is_ready() {
            online_rec.decode_stream();
        }

        let result = online_rec.get_result();
        if !result.is_empty() && result != last_text {
            println!("\t{}", result.to_lowercase());
            last_text.clone_from(&result);
        }

        if online_rec.is_endpoint() {
            if !result.is_empty() {
                // No decode call runs between the `get_result` above and here, so
                // `result` is still the current hypothesis.
                println!("final result:{}", result.to_lowercase());
                segment_index += 1;
                println!("current segment: {}", segment_index);
            }
            online_rec.reset();
            // Forget the finished segment's text; otherwise an identical first partial in
            // the next segment would be suppressed by the comparison above.
            last_text.clear();
        }
    }
}