Switch the base to ort v2.0.0-rc.1
qryxip committed Apr 28, 2024
1 parent b147977 commit fc968d1
Showing 4 changed files with 50 additions and 60 deletions.
86 changes: 42 additions & 44 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -88,7 +88,7 @@ zip = "0.6.3"
 
 [workspace.dependencies.voicevox-ort]
 git = "https://github.com/qryxip/ort.git"
-rev = "37af007322f0dd5a21e536ab3bcf727970f1283a"
+rev = "59e94ac87732e9da3f83ebbd542a3062f3cf2264"
 
 [workspace.dependencies.open_jtalk]
 git = "https://github.com/VOICEVOX/open_jtalk-rs.git"
1 change: 1 addition & 0 deletions crates/voicevox_core/src/infer.rs
@@ -79,6 +79,7 @@ pub(crate) trait InferenceInputSignature: Send + 'static {
 pub(crate) trait InputScalar: Sized {
     const KIND: InputScalarKind;
 
+    // TODO: It may be possible to take an `ArrayView` instead of an `Array`
     fn push_tensor_to_ctx(
         tensor: Array<Self, impl Dimension + 'static>,
         visitor: &mut impl PushInputTensor,
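For context on the TODO added above, here is a minimal sketch of what a borrowing variant could look like, assuming ndarray's `ArrayView`; the `PushInputTensor` trait shown is a stand-in for the crate's actual visitor trait, not its real definition:

use ndarray::{ArrayView, Dimension};

// Stand-in for the crate's visitor trait; the real `PushInputTensor`
// has one method per supported element type.
trait PushInputTensor {
    fn push_f32(&mut self, tensor: ArrayView<'_, f32, impl Dimension>);
}

// Hypothetical borrowing variant of `push_tensor_to_ctx`: taking a view
// would let callers avoid cloning tensor data into an owned `Array` first.
trait InputScalar: Sized {
    fn push_tensor_to_ctx(
        tensor: ArrayView<'_, Self, impl Dimension>,
        visitor: &mut impl PushInputTensor,
    );
}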
21 changes: 6 additions & 15 deletions crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
@@ -8,7 +8,6 @@ use ort::{
     ExecutionProviderDispatch, GraphOptimizationLevel, IntoTensorElementType, TensorElementType,
     ValueType,
 };
-use tracing::warn;
 
 use crate::{devices::SupportedDevices, error::ErrorRepr};
 
@@ -56,17 +55,9 @@ impl InferenceRuntime for Onnxruntime {
         // TODO: Create `InferenceRuntime::init` and `InitInferenceRuntimeError`
         build_ort_env_once().unwrap();
 
-        let cpu_num_threads = options.cpu_num_threads.try_into().unwrap_or_else(|_| {
-            warn!(
-                "`cpu_num_threads={}` is too large. Setting it to 32767",
-                options.cpu_num_threads,
-            );
-            i16::MAX
-        });
-
         let builder = ort::Session::builder()?
             .with_optimization_level(GraphOptimizationLevel::Level1)?
-            .with_intra_threads(cpu_num_threads)?;
+            .with_intra_threads(options.cpu_num_threads.into())?;
 
         let builder = if options.use_gpu && cfg!(feature = "directml") {
             builder
Expand All @@ -84,7 +75,7 @@ impl InferenceRuntime for Onnxruntime {
};

let model = model()?;
let sess = builder.with_model_from_memory(&{ model })?;
let sess = builder.commit_from_memory(&{ model })?;

let input_param_infos = sess
.inputs
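Taken together, the two builder changes above track the rc.1 API: `with_model_from_memory` was renamed to `commit_from_memory`, and `with_intra_threads` now accepts a wider integer, which makes the old clamp to `i16::MAX` (and the `warn!` about it) unnecessary. A minimal sketch of the resulting construction path, assuming ort 2.0.0-rc.1 and omitting the GPU branches:

use ort::{GraphOptimizationLevel, Session};

fn build_session(model: &[u8], cpu_num_threads: u16) -> ort::Result<Session> {
    // `with_intra_threads` takes a plain `usize` here, so a `u16` thread
    // count converts losslessly; no clamping needed.
    Session::builder()?
        .with_optimization_level(GraphOptimizationLevel::Level1)?
        .with_intra_threads(cpu_num_threads.into())?
        // Renamed from `with_model_from_memory` in earlier ort 2.0 builds.
        .commit_from_memory(model)
}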
@@ -187,8 +178,8 @@
 
         match ty {
             TensorElementType::Float32 => {
-                let output = output.extract_tensor::<f32>()?;
-                Ok(OutputTensor::Float32(output.view().clone().into_owned()))
+                let output = output.try_extract_tensor::<f32>()?;
+                Ok(OutputTensor::Float32(output.into_owned()))
             }
             _ => bail!("unexpected output tensor element data type"),
         }
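In rc.1, `try_extract_tensor` yields a borrowed `ArrayView` rather than an owning wrapper, which is why the extra `view().clone()` above collapses into a single `into_owned()`. A minimal sketch of the extraction, assuming ort 2.0.0-rc.1 with its ndarray integration enabled:

use ndarray::ArrayD;

// Extract an owned f32 tensor from an ort output value. `try_extract_tensor`
// borrows from `value`, so `into_owned` copies the data out.
fn extract_f32(value: &ort::Value) -> anyhow::Result<ArrayD<f32>> {
    let view = value.try_extract_tensor::<f32>()?;
    Ok(view.into_owned())
}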
@@ -205,7 +196,7 @@ fn build_ort_env_once() -> ort::Result<()> {
 
 pub(crate) struct OnnxruntimeRunContext<'sess> {
     sess: &'sess ort::Session,
-    inputs: Vec<ort::SessionInputValue<'static>>,
+    inputs: Vec<ort::SessionInputValue<'static>>,
 }
 
 impl OnnxruntimeRunContext<'_> {
@@ -216,7 +207,7 @@ impl OnnxruntimeRunContext<'_> {
             impl Dimension + 'static,
         >,
     ) -> anyhow::Result<()> {
-        let input = input.try_into()?;
+        let input = ort::Value::from_array(input)?.into();
         self.inputs.push(input);
         Ok(())
     }
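A minimal sketch of the new input path, mirroring the change above: an owned `ndarray::Array` goes through `ort::Value::from_array`, and the result converts via `into()` to the `SessionInputValue` that the session consumes (assuming ort 2.0.0-rc.1):

use ndarray::ArrayD;

// Wrap an owned dynamic-dimension array as a session input value.
fn to_session_input(array: ArrayD<f32>) -> ort::Result<ort::SessionInputValue<'static>> {
    Ok(ort::Value::from_array(array)?.into())
}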
