diff --git a/.github/workflows/gen_whl_to_pypi.yml b/.github/workflows/gen_whl_to_pypi.yml
new file mode 100644
index 0000000..1549a8a
--- /dev/null
+++ b/.github/workflows/gen_whl_to_pypi.yml
@@ -0,0 +1,76 @@
+name: Push rapid_paraformer to PyPI
+
+on:
+  push:
+    branches: [ main ]
+    paths:
+      - 'python/rapid_paraformer/**'
+      - 'python/docs/doc_whl.md'
+      - 'python/setup.py'
+      - '.github/workflows/gen_whl_to_pypi.yml'
+
+# env:
+#   RESOURCES_URL: https://github.com/RapidAI/RapidLatexOCR/releases/download/v0.0.0/models.zip
+
+jobs:
+  # UnitTesting:
+  #   runs-on: ubuntu-latest
+  #   steps:
+  #     - name: Pull latest code
+  #       uses: actions/checkout@v3
+
+  #     - name: Set up Python 3.7
+  #       uses: actions/setup-python@v4
+  #       with:
+  #         python-version: '3.7'
+  #         architecture: 'x64'
+
+  #     - name: Display Python version
+  #       run: python -c "import sys; print(sys.version)"
+
+  #     - name: Download models
+  #       run: |
+  #         wget $RESOURCES_URL
+  #         ZIP_NAME=${RESOURCES_URL##*/}
+  #         DIR_NAME=${ZIP_NAME%.*}
+  #         unzip $ZIP_NAME
+
+  #     - name: Unit testing with rapid_paraformer
+  #       run: |
+  #         pip install -r requirements.txt
+  #         pip install pytest
+  #         pytest tests/test*.py
+
+  GenerateWHL_PushPyPi:
+    # needs: UnitTesting (restore once the UnitTesting job above is re-enabled)
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v3
+
+      - name: Set up Python 3.7
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.7'
+          architecture: 'x64'
+
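+      # The commit message is forwarded to setup.py below, which extracts a
+      # semantic version (e.g. "v2.0.4") from it; if none is present, the
+      # latest published PyPI version is bumped by one (see python/setup.py).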
+      - name: Run setup.py
+        run: |
+          cd python
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install wheel get_pypi_latest_version
+          python setup.py bdist_wheel ${{ github.event.head_commit.message }}
+
+      # - name: Publish distribution 📦 to Test PyPI
+      #   uses: pypa/gh-action-pypi-publish@v1.5.0
+      #   with:
+      #     password: ${{ secrets.TEST_PYPI_API_TOKEN }}
+      #     repository_url: https://test.pypi.org/legacy/
+      #     packages_dir: dist/
+
+      - name: Publish distribution 📦 to PyPI
+        uses: pypa/gh-action-pypi-publish@v1.5.0
+        with:
+          password: ${{ secrets.PYPI_API_TOKEN }}
+          packages_dir: python/dist/
diff --git a/.gitignore b/.gitignore
index f3342c9..5f3ee39 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,5 @@
 *.onnx
-
-*.pth
+*.json
 
 # Created by .ignore support plugin (hsz.mobi)
 ### Python template
@@ -21,6 +20,8 @@ dist/
 downloads/
 eggs/
 .eggs/
+lib/
+lib64/
 parts/
 sdist/
 var/
@@ -138,6 +139,7 @@ dmypy.json
 .vs
 .vscode
 .idea
+/models
 
 #models
diff --git a/README.md b/README.md
index ce66dc5..3f9dcba 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,9 @@ A([wav]) --RapidVad--> B([各个小段的音频]) --RapidASR--> C([识别的文字])
 
 #### 📣Changelog
 <details>
 <summary>Details</summary>
-
+- 2023-08-21 v2.0.4 update:
+  - Added whl package support
+  - Updated the docs
 - 2023-02-25
   - Added C++ inference using the onnxruntime engine; pre/post-processing code comes from [FastASR](https://github.com/chenkui164/FastASR)
 - 2023-02-14 v2.0.3 update:
diff --git a/python/.pre-commit-config.yaml b/python/.pre-commit-config.yaml
new file mode 100644
index 0000000..5c227d6
--- /dev/null
+++ b/python/.pre-commit-config.yaml
@@ -0,0 +1,19 @@
+repos:
+- repo: https://gitee.com/SWHL/autoflake
+  rev: v2.1.1
+  hooks:
+  - id: autoflake
+    args:
+      [
+        "--recursive",
+        "--in-place",
+        "--remove-all-unused-imports",
+        "--remove-unused-variables",
+        "--ignore-init-module-imports",
+      ]
+    files: \.py$
+- repo: https://gitee.com/SWHL/black
+  rev: 23.1.0
+  hooks:
+  - id: black
+    files: \.py$
\ No newline at end of file
diff --git a/python/README.md b/python/README.md
index 0d4a311..25aff77 100644
--- a/python/README.md
+++ b/python/README.md
@@ -1,8 +1,12 @@
-## Rapid ASR
+## rapid_paraformer
 <div align="center">
 
-
+  <!-- badges: PyPI version | SemVer 2.0 -->
+
 </div>
 
 - The model comes from Alibaba DAMO Academy: [Paraformer speech recognition, Chinese, general, 16k, offline, large, pytorch](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
@@ -13,61 +17,44 @@
 - [ ] Integrate the vad + asr + pun models into a deployable solution
 
 #### Usage
-1. Set up the environment
-   ```bash
-   pip install -r requirements.txt
-   ```
-2. Download the model
-   - The model is large (823.8M), so it is not hosted in the repo and is awkward to download,
-   - (Recommended) Export it yourself: a ModelScope notebook can convert it in one step; see [Quick experience](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
-     - Open a notebook → enter `!python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true` in a cell and run it.
-   - Baidu NetDisk download link: [asr_paraformerv2.onnx](https://pan.baidu.com/s/1-nEf2eUpkzlcRqiYEwub2A?pwd=dcr3) (model MD5: `9ca331381a470bc4458cc6c0b0b165de`)
-   - After downloading, put the model under `resources/models`; the final layout is:
-   ```text
-   .
-   ├── demo.py
-   ├── rapid_paraformer
-   │   ├── __init__.py
-   │   ├── kaldifeat
-   │   ├── __pycache__
-   │   ├── rapid_paraformer.py
-   │   └── utils.py
-   ├── README.md
-   ├── requirements.txt
-   ├── resources
-   │   ├── config.yaml
-   │   └── models
-   │       ├── am.mvn
-   │       ├── asr_paraformerv2.onnx  # put it here
-   │       └── token_list.pkl
-   ├── test_onnx.py
-   ├── tests
-   │   ├── __pycache__
-   │   └── test_infer.py
-   └── test_wavs
-       ├── 0478_00017.wav
-       └── asr_example_zh.wav
-   ```
-
-3. Run the demo
+1. Install
+   1. Install `rapid_paraformer`
+      ```bash
+      pip install rapid_paraformer
+      ```
+   2. Download **resources.zip** ([Google Drive](https://drive.google.com/drive/folders/1RVQtMe0eB_k6G5TJlmXwPELx4VtF2oCw?usp=sharing) | [Baidu NetDisk](https://pan.baidu.com/s/1zf8Ta6QxFHY3Z75fHNYKrQ?pwd=6ekq))
+      ```bash
+      resources
+      ├── [ 700]  config.yaml
+      └── [4.0K]  models
+          ├── [ 11K]  am.mvn
+          ├── [824M]  asr_paraformerv2.onnx
+          └── [ 50K]  token_list.pkl
+      ```
+   3. The **asr_paraformerv2.onnx** file can also be exported yourself in a ModelScope notebook:
+      1. Open [Quick experience](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
+      2. Open a notebook → enter the following command in a cell and run it.
+         ```python
+         !python -m funasr.export.export_model --model-name 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' --export-dir "./export"
+         ```
+
+2. Use
    ```python
    from rapid_paraformer import RapidParaformer
 
+   config_path = "resources/config.yaml"
-   config_path = 'resources/config.yaml'
    paraformer = RapidParaformer(config_path)
 
-   # Input: supports Union[str, np.ndarray, List[str]]
-   # Output: List[asr_res]
    wav_path = [
-       'test_wavs/0478_00017.wav',
+       "test_wavs/0478_00017.wav",
+       "test_wavs/asr_example_zh.wav",
    ]
    result = paraformer(wav_path)
    print(result)
    ```
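+
+   The input also accepts a single wav path or a raw `np.ndarray` waveform
+   (`Union[str, np.ndarray, List[str]]`); a minimal sketch based on
+   `load_data()` in `rapid_paraformer.py`:
+   ```python
+   import librosa
+
+   # sr=None keeps the native sample rate; [None, ...] adds the batch axis
+   waveform, _ = librosa.load("test_wavs/asr_example_zh.wav", sr=None)
+   print(paraformer(waveform[None, ...]))
+   ```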
-4. View the result
-   ```text
-   ['呃说不配合就不配合的好以上的话呢我们摘取八九十三条因为这三条的话呢比较典型啊一些数字比较明确尤其是时间那么我们要投资者就是了解这一点啊不要轻信这个市场可以快速回来啊这些配市公司啊后期又利好了可以快速快速攻能包括像前一段时间啊有些媒体在二三月份的时候']
-   ```
+3. View the result
+   ```text
+   ['y', '欢迎大家来体验达摩院推出的语音识别模型']
+   ```
diff --git a/python/demo.py b/python/demo.py
index c5170e6..1a9798a 100644
--- a/python/demo.py
+++ b/python/demo.py
@@ -3,21 +3,15 @@
 # @Contact: liekkaskono@163.com
 from rapid_paraformer import RapidParaformer
 
-
-config_path = 'resources/config.yaml'
+config_path = "resources/config.yaml"
 
 paraformer = RapidParaformer(config_path)
 
 wav_path = [
-    'test_wavs/0478_00017.wav',
-    'test_wavs/asr_example_zh.wav',
-    'test_wavs/0478_00017.wav',
-    'test_wavs/asr_example_zh.wav',
-    'test_wavs/0478_00017.wav',
-    'test_wavs/asr_example_zh.wav',
+    "test_wavs/0478_00017.wav",
+    "test_wavs/asr_example_zh.wav",
 ]
 print(wav_path)
 
-# wav_path = 'test_wavs/0478_00017.wav'
 result = paraformer(wav_path)
 print(result)
diff --git a/python/docs/doc_whl.md b/python/docs/doc_whl.md
new file mode 100644
index 0000000..395bb35
--- /dev/null
+++ b/python/docs/doc_whl.md
@@ -0,0 +1,45 @@
+## rapid_paraformer
+
+<div align="center">
+
+  <!-- badges: PyPI version | SemVer 2.0 -->
+
+</div>
+
+### Use
+1. Install
+   1. Install `rapid_paraformer`
+      ```bash
+      pip install rapid_paraformer
+      ```
+   2. Download **resources.zip** ([Google Drive](https://drive.google.com/drive/folders/1RVQtMe0eB_k6G5TJlmXwPELx4VtF2oCw?usp=sharing) | [Baidu NetDisk](https://pan.baidu.com/s/1zf8Ta6QxFHY3Z75fHNYKrQ?pwd=6ekq))
+      ```bash
+      resources
+      ├── [ 700]  config.yaml
+      └── [4.0K]  models
+          ├── [ 11K]  am.mvn
+          ├── [824M]  asr_paraformerv2.onnx
+          └── [ 50K]  token_list.pkl
+      ```
+2. Use
+   ```python
+   from rapid_paraformer import RapidParaformer
+
+   config_path = "resources/config.yaml"
+
+   paraformer = RapidParaformer(config_path)
+
+   wav_path = [
+       "test_wavs/0478_00017.wav",
+       "test_wavs/asr_example_zh.wav",
+   ]
+
+   result = paraformer(wav_path)
+   print(result)
+   ```
+
+### See [RapidASR](https://github.com/RapidAI/RapidASR) for details.
diff --git a/python/rapid_paraformer/rapid_paraformer.py b/python/rapid_paraformer/rapid_paraformer.py
index 34b3692..e7a37d1 100644
--- a/python/rapid_paraformer/rapid_paraformer.py
+++ b/python/rapid_paraformer/rapid_paraformer.py
@@ -1,35 +1,41 @@
 # -*- encoding: utf-8 -*-
 # @Author: SWHL
 # @Contact: liekkaskono@163.com
-import traceback
 from pathlib import Path
-from typing import List, Union, Tuple
+from typing import List, Tuple, Union
 
 import librosa
 import numpy as np
 
-from .utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
-                    OrtInferSession, TokenIDConverter, WavFrontend, get_logger,
-                    read_yaml)
+from .utils import (
+    CharTokenizer,
+    Hypothesis,
+    ONNXRuntimeError,
+    OrtInferSession,
+    TokenIDConverter,
+    WavFrontend,
+    get_logger,
+    read_yaml,
+)
 
 logging = get_logger()
 
 
-class RapidParaformer():
+class RapidParaformer:
     def __init__(self, config_path: Union[str, Path]) -> None:
         if not Path(config_path).exists():
-            raise FileNotFoundError(f'{config_path} does not exist.')
+            raise FileNotFoundError(f"{config_path} does not exist.")
 
         config = read_yaml(config_path)
 
-        self.converter = TokenIDConverter(**config['TokenIDConverter'])
-        self.tokenizer = CharTokenizer(**config['CharTokenizer'])
+        self.converter = TokenIDConverter(**config["TokenIDConverter"])
+        self.tokenizer = CharTokenizer(**config["CharTokenizer"])
         self.frontend = WavFrontend(
-            cmvn_file=config['WavFrontend']['cmvn_file'],
-            **config['WavFrontend']['frontend_conf']
+            cmvn_file=config["WavFrontend"]["cmvn_file"],
+            **config["WavFrontend"]["frontend_conf"],
         )
-        self.ort_infer = OrtInferSession(config['Model'])
-        self.batch_size = config['Model']['batch_size']
+        self.ort_infer = OrtInferSession(config["Model"])
+        self.batch_size = config["Model"]["batch_size"]
 
     def __call__(self, wav_content: Union[str, np.ndarray, List[str]]) -> List:
         waveform_list = self.load_data(wav_content)
@@ -52,8 +58,7 @@ def __call__(self, wav_content: Union[str, np.ndarray, List[str]]) -> List:
         asr_res.extend(preds)
         return asr_res
 
-    def load_data(self,
-                  wav_content: Union[str, np.ndarray, List[str]]) -> List:
+    def load_data(self, wav_content: Union[str, np.ndarray, List[str]]) -> List:
         def load_wav(path: str) -> np.ndarray:
             waveform, _ = librosa.load(path, sr=None)
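+            # sr=None keeps the file's native sample rate; the added leading
+            # axis gives the (1, num_samples) layout the frontend expects.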
             return waveform[None, ...]
@@ -67,12 +72,11 @@ def load_wav(path: str) -> np.ndarray:
 
         if isinstance(wav_content, list):
             return [load_wav(path) for path in wav_content]
 
-        raise TypeError(
-            f'The type of {wav_content} is not in [str, np.ndarray, list]')
+        raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")
 
-    def extract_feat(self,
-                     waveform_list: List[np.ndarray]
-                     ) -> Tuple[np.ndarray, np.ndarray]:
+    def extract_feat(
+        self, waveform_list: List[np.ndarray]
+    ) -> Tuple[np.ndarray, np.ndarray]:
         feats, feats_len = [], []
         for waveform in waveform_list:
             speech, _ = self.frontend.fbank(waveform)
@@ -88,24 +92,25 @@ def extract_feat(self,
     def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
         def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
             pad_width = ((0, max_feat_len - cur_len), (0, 0))
-            return np.pad(feat, pad_width, 'constant', constant_values=0)
+            return np.pad(feat, pad_width, "constant", constant_values=0)
 
         feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
         feats = np.array(feat_res).astype(np.float32)
         return feats
 
-    def infer(self, feats: np.ndarray,
-              feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+    def infer(
+        self, feats: np.ndarray, feats_len: np.ndarray
+    ) -> Tuple[np.ndarray, np.ndarray]:
         am_scores, token_nums = self.ort_infer([feats, feats_len])
         return am_scores, token_nums
 
     def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
-        return [self.decode_one(am_score, token_num)
-                for am_score, token_num in zip(am_scores, token_nums)]
+        return [
+            self.decode_one(am_score, token_num)
+            for am_score, token_num in zip(am_scores, token_nums)
+        ]
 
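+    # Greedy decoding: take the argmax token per frame, map the ids back to
+    # tokens, and trim to the valid token count reported by the model.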
-    def decode_one(self,
-                   am_score: np.ndarray,
-                   valid_token_num: int) -> List[str]:
+    def decode_one(self, am_score: np.ndarray, valid_token_num: int) -> List[str]:
         yseq = am_score.argmax(axis=-1)
         score = am_score.max(axis=-1)
         score = np.sum(score, axis=-1)
@@ -125,15 +130,15 @@ def decode_one(self,
         # Change integer-ids to tokens
         token = self.converter.ids2tokens(token_int)
         text = self.tokenizer.tokens2text(token)
-        return text[:valid_token_num-1]
+        return text[: valid_token_num - 1]
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     project_dir = Path(__file__).resolve().parent.parent
-    cfg_path = project_dir / 'resources' / 'config.yaml'
+    cfg_path = project_dir / "resources" / "config.yaml"
     paraformer = RapidParaformer(cfg_path)
 
-    wav_file = '0478_00017.wav'
+    wav_file = "0478_00017.wav"
     for i in range(1000):
         result = paraformer(wav_file)
         print(result)
diff --git a/python/rapid_paraformer/utils.py b/python/rapid_paraformer/utils.py
index 829e36d..2600526 100644
--- a/python/rapid_paraformer/utils.py
+++ b/python/rapid_paraformer/utils.py
@@ -4,13 +4,19 @@
 import functools
 import logging
 import pickle
+import warnings
 from pathlib import Path
 from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
 
 import numpy as np
 import yaml
-from onnxruntime import (GraphOptimizationLevel, InferenceSession,
-                         SessionOptions, get_available_providers, get_device)
+from onnxruntime import (
+    GraphOptimizationLevel,
+    InferenceSession,
+    SessionOptions,
+    get_available_providers,
+    get_device,
+)
 from typeguard import check_argument_types
 
 from .kaldifeat import compute_fbank_feats
@@ -20,9 +26,12 @@
 logger_initialized = {}
 
 
-class TokenIDConverter():
-    def __init__(self, token_path: Union[Path, str],
-                 unk_symbol: str = "<unk>",):
+class TokenIDConverter:
+    def __init__(
+        self,
+        token_path: Union[Path, str],
+        unk_symbol: str = "<unk>",
+    ):
         check_argument_types()
 
         self.token_list = self.load_token(token_path)
@@ -31,23 +40,23 @@ def __init__(self, token_path: Union[Path, str],
 
     @staticmethod
     def load_token(file_path: Union[Path, str]) -> List:
         if not Path(file_path).exists():
-            raise TokenIDConverterError(f'The {file_path} does not exist.')
+            raise TokenIDConverterError(f"The {file_path} does not exist.")
 
-        with open(str(file_path), 'rb') as f:
+        with open(str(file_path), "rb") as f:
             token_list = pickle.load(f)
 
         if len(token_list) != len(set(token_list)):
-            raise TokenIDConverterError('The Token exists duplicated symbol.')
+            raise TokenIDConverterError("The token list contains duplicated symbols.")
         return token_list
 
     def get_num_vocabulary_size(self) -> int:
         return len(self.token_list)
 
-    def ids2tokens(self,
-                   integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
+    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
         if isinstance(integers, np.ndarray) and integers.ndim != 1:
             raise TokenIDConverterError(
-                f"Must be 1 dim ndarray, but got {integers.ndim}")
+                f"Must be a 1-dim ndarray, but got {integers.ndim}"
+            )
         return [self.token_list[i] for i in integers]
 
     def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
@@ -60,7 +69,7 @@ def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
         return [token2id.get(i, unk_id) for i in tokens]
 
 
-class CharTokenizer():
+class CharTokenizer:
     def __init__(
         self,
         symbol_value: Union[Path, str, Iterable[str]] = None,
@@ -96,7 +105,7 @@ def text2tokens(self, line: Union[str, list]) -> List[str]:
                 if line.startswith(w):
                     if not self.remove_non_linguistic_symbols:
                         tokens.append(line[: len(w)])
-                    line = line[len(w):]
+                    line = line[len(w) :]
                     break
             else:
                 t = line[0]
@@ -119,23 +128,22 @@ def __repr__(self):
         )
 
 
-class WavFrontend():
-    """Conventional frontend structure for ASR.
-    """
+class WavFrontend:
+    """Conventional frontend structure for ASR."""
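+    # Pipeline: fbank feature extraction, then optional LFR frame stacking
+    # and CMVN normalization (see apply_lfr / apply_cmvn below).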
- """ +class WavFrontend: + """Conventional frontend structure for ASR.""" def __init__( - self, - cmvn_file: str = None, - fs: int = 16000, - window: str = 'hamming', - n_mels: int = 80, - frame_length: int = 25, - frame_shift: int = 10, - filter_length_min: int = -1, - filter_length_max: float = -1, - lfr_m: int = 1, - lfr_n: int = 1, - dither: float = 1.0 + self, + cmvn_file: str = None, + fs: int = 16000, + window: str = "hamming", + n_mels: int = 80, + frame_length: int = 25, + frame_shift: int = 10, + filter_length_min: int = -1, + filter_length_max: float = -1, + lfr_m: int = 1, + lfr_n: int = 1, + dither: float = 1.0, ) -> None: check_argument_types() @@ -154,19 +162,20 @@ def __init__( if self.cmvn_file: self.cmvn = self.load_cmvn() - def fbank(self, - input_content: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + def fbank(self, input_content: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: waveform_len = input_content.shape[1] waveform = input_content[0][:waveform_len] waveform = waveform * (1 << 15) - mat = compute_fbank_feats(waveform, - num_mel_bins=self.n_mels, - frame_length=self.frame_length, - frame_shift=self.frame_shift, - dither=self.dither, - energy_floor=0.0, - window_type=self.window, - sample_frequency=self.fs) + mat = compute_fbank_feats( + waveform, + num_mel_bins=self.n_mels, + frame_length=self.frame_length, + frame_shift=self.frame_shift, + dither=self.dither, + energy_floor=0.0, + window_type=self.window, + sample_frequency=self.fs, + ) feat = mat.astype(np.float32) feat_len = np.array(mat.shape[0]).astype(np.int32) return feat, feat_len @@ -193,11 +202,12 @@ def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray: for i in range(T_lfr): if lfr_m <= T - i * lfr_n: LFR_inputs.append( - (inputs[i * lfr_n:i * lfr_n + lfr_m]).reshape(1, -1)) + (inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1) + ) else: # process last LFR frame num_padding = lfr_m - (T - i * lfr_n) - frame = inputs[i * lfr_n:].reshape(-1) + frame = inputs[i * lfr_n :].reshape(-1) for _ in range(num_padding): frame = np.hstack((frame, inputs[-1])) @@ -215,24 +225,26 @@ def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray: inputs = (inputs + means) * vars return inputs - def load_cmvn(self,) -> np.ndarray: - with open(self.cmvn_file, 'r', encoding='utf-8') as f: + def load_cmvn( + self, + ) -> np.ndarray: + with open(self.cmvn_file, "r", encoding="utf-8") as f: lines = f.readlines() means_list = [] vars_list = [] for i in range(len(lines)): line_item = lines[i].split() - if line_item[0] == '': + if line_item[0] == "": line_item = lines[i + 1].split() - if line_item[0] == '': - add_shift_line = line_item[3:(len(line_item) - 1)] + if line_item[0] == "": + add_shift_line = line_item[3 : (len(line_item) - 1)] means_list = list(add_shift_line) continue - elif line_item[0] == '': + elif line_item[0] == "": line_item = lines[i + 1].split() - if line_item[0] == '': - rescale_line = line_item[3:(len(line_item) - 1)] + if line_item[0] == "": + rescale_line = line_item[3 : (len(line_item) - 1)] vars_list = list(rescale_line) continue @@ -247,8 +259,8 @@ class Hypothesis(NamedTuple): yseq: np.ndarray score: Union[float, np.ndarray] = 0 - scores: Dict[str, Union[float, np.ndarray]] = dict() - states: Dict[str, Any] = dict() + scores: Dict[str, Union[float, np.ndarray]] = {} + states: Dict[str, Any] = {} def asdict(self) -> dict: """Convert data to JSON-friendly dict.""" @@ -267,56 +279,64 @@ class ONNXRuntimeError(Exception): pass -class OrtInferSession(): +class OrtInferSession: def 
     def __init__(self, config):
         sess_opt = SessionOptions()
         sess_opt.log_severity_level = 4
         sess_opt.enable_cpu_mem_arena = False
         sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
 
-        cuda_ep = 'CUDAExecutionProvider'
-        cpu_ep = 'CPUExecutionProvider'
+        cuda_ep = "CUDAExecutionProvider"
+        cpu_ep = "CPUExecutionProvider"
 
         cpu_provider_options = {
             "arena_extend_strategy": "kSameAsRequested",
         }
 
         EP_list = []
-        if config['use_cuda'] and get_device() == 'GPU' \
-                and cuda_ep in get_available_providers():
+        if (
+            config["use_cuda"]
+            and get_device() == "GPU"
+            and cuda_ep in get_available_providers()
+        ):
             EP_list = [(cuda_ep, config[cuda_ep])]
         EP_list.append((cpu_ep, cpu_provider_options))
 
-        config['model_path'] = config['model_path']
-        self._verify_model(config['model_path'])
-        self.session = InferenceSession(config['model_path'],
-                                        sess_options=sess_opt,
-                                        providers=EP_list)
-
-        if config['use_cuda'] and cuda_ep not in self.session.get_providers():
-            warnings.warn(f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n'
-                          'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, '
-                          'you can check their relations from the offical web site: '
-                          'https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html',
-                          RuntimeWarning)
-
-    def __call__(self,
-                 input_content: List[np.ndarray]) -> np.ndarray:
+        self._verify_model(config["model_path"])
+        self.session = InferenceSession(
+            config["model_path"], sess_options=sess_opt, providers=EP_list
+        )
+
+        if config["use_cuda"] and cuda_ep not in self.session.get_providers():
+            warnings.warn(
+                f"{cuda_ep} is not available in the current env; inference automatically falls back to {cpu_ep}.\n"
+                "Please ensure the installed onnxruntime-gpu version matches your CUDA and cuDNN versions; "
+                "you can check their compatibility on the official site: "
+                "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
+                RuntimeWarning,
+            )
+
+    def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
         input_dict = dict(zip(self.get_input_names(), input_content))
         try:
             return self.session.run(None, input_dict)
         except Exception as e:
-            raise ONNXRuntimeError('ONNXRuntime inference failed.') from e
+            raise ONNXRuntimeError("ONNXRuntime inference failed.") from e
 
-    def get_input_names(self, ):
+    def get_input_names(
+        self,
+    ):
         return [v.name for v in self.session.get_inputs()]
 
-    def get_output_names(self,):
+    def get_output_names(
+        self,
+    ):
         return [v.name for v in self.session.get_outputs()]
 
-    def get_character_list(self, key: str = 'character'):
+    def get_character_list(self, key: str = "character"):
         return self.meta_dict[key].splitlines()
 
-    def have_key(self, key: str = 'character') -> bool:
+    def have_key(self, key: str = "character") -> bool:
         self.meta_dict = self.session.get_modelmeta().custom_metadata_map
         if key in self.meta_dict.keys():
             return True
@@ -326,22 +346,22 @@ def have_key(self, key: str = 'character') -> bool:
     def _verify_model(model_path):
         model_path = Path(model_path)
         if not model_path.exists():
-            raise FileNotFoundError(f'{model_path} does not exists.')
+            raise FileNotFoundError(f"{model_path} does not exist.")
         if not model_path.is_file():
-            raise FileExistsError(f'{model_path} is not a file.')
+            raise FileExistsError(f"{model_path} is not a file.")
 
 
 def read_yaml(yaml_path: Union[str, Path]) -> Dict:
     if not Path(yaml_path).exists():
-        raise FileExistsError(f'The {yaml_path} does not exist.')
+        raise FileNotFoundError(f"The {yaml_path} does not exist.")
 
-    with open(str(yaml_path), 'rb') as f:
+    with open(str(yaml_path), "rb") as f:
         data = yaml.load(f, Loader=yaml.Loader)
     return data
 
 
 @functools.lru_cache()
-def get_logger(name='rapdi_paraformer'):
+def get_logger(name="rapid_paraformer"):
     """Initialize and get a logger by name.
     If the logger has not been initialized, this method will initialize the
     logger by adding one or two handlers, otherwise the initialized logger will
@@ -361,8 +381,8 @@ def get_logger(name='rapdi_paraformer'):
         return logger
 
     formatter = logging.Formatter(
-        '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
-        datefmt="%Y/%m/%d %H:%M:%S")
+        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
+    )
 
     sh = logging.StreamHandler()
     sh.setFormatter(formatter)
diff --git a/python/requirements.txt b/python/requirements.txt
index f19d646..374fd93 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -2,4 +2,4 @@ librosa
 numpy
 onnxruntime
 scipy
-typeguard>=2.13.3
+typeguard==2.13.3
diff --git a/python/resources/config.yaml b/python/resources/config.yaml
index 83736a4..0e7621f 100644
--- a/python/resources/config.yaml
+++ b/python/resources/config.yaml
@@ -21,7 +21,7 @@ WavFrontend:
     dither: 0.0
 
 Model:
-    model_path: resources/models/model.onnx
+    model_path: resources/models/asr_paraformerv2.onnx
    use_cuda: false
    CUDAExecutionProvider:
        device_id: 0
diff --git a/python/setup.py b/python/setup.py
new file mode 100644
index 0000000..0d1aa31
--- /dev/null
+++ b/python/setup.py
@@ -0,0 +1,75 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+import sys
+from pathlib import Path
+from typing import List
+
+import setuptools
+from get_pypi_latest_version import GetPyPiLatestVersion
+
+
+def read_txt(txt_path: str) -> List:
+    if not isinstance(txt_path, str):
+        txt_path = str(txt_path)
+
+    with open(txt_path, "r", encoding="utf-8") as f:
+        data = list(map(lambda x: x.rstrip("\n"), f))
+    return data
+
+
+def get_readme() -> str:
+    root_dir = Path(__file__).resolve().parent
+    readme_path = str(root_dir / "docs" / "doc_whl.md")
+    with open(readme_path, "r", encoding="utf-8") as f:
+        readme = f.read()
+    return readme
+
+
+MODULE_NAME = "rapid_paraformer"
+
+obtainer = GetPyPiLatestVersion()
+try:
+    latest_version = obtainer(MODULE_NAME)
+except ValueError:
+    latest_version = "0.0.1"
+
+VERSION_NUM = obtainer.version_add_one(latest_version)
+
+# Prefer a semantic version extracted from the commit message; if none is
+# present, automatically bump the latest PyPI version by one.
+if len(sys.argv) > 2:
+    match_str = " ".join(sys.argv[2:])
+    matched_versions = obtainer.extract_version(match_str)
+    if matched_versions:
+        VERSION_NUM = matched_versions
+sys.argv = sys.argv[:2]
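+
+# For example, a commit message containing "v2.0.4" publishes version 2.0.4,
+# while a message without a version publishes the auto-bumped VERSION_NUM.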
+
+setuptools.setup(
+    name=MODULE_NAME,
+    version=VERSION_NUM,
+    platforms="Any",
+    description="Tool for speech recognition based on Paraformer.",
+    long_description=get_readme(),
+    long_description_content_type="text/markdown",
+    author="SWHL",
+    author_email="liekkaskono@163.com",
+    url="https://github.com/RapidAI/RapidASR",
+    license="Apache-2.0",
+    include_package_data=True,
+    install_requires=read_txt("requirements.txt"),
+    packages=[MODULE_NAME, f"{MODULE_NAME}.kaldifeat"],
+    package_data={"": ["*.md", "LICENSE"]},
+    keywords=["asr", "paraformer", "wenet"],
+    classifiers=[
+        "Programming Language :: Python :: 3.6",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
+    ],
+    python_requires=">=3.6,<3.12",
+    entry_points={
+        "console_scripts": [f"{MODULE_NAME}={MODULE_NAME}.main:main"],
+    },
+)