Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: compatible_engineにsafety requirementとアサートを入れる #869

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 56 additions & 9 deletions crates/voicevox_core_c_api/src/compatible_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,22 +221,31 @@ pub extern "C" fn supported_devices() -> *const c_char {
});
}

// SAFETY: voicevox_core_c_apiを構成するライブラリの中に、これと同名のシンボルは存在しない
#[unsafe(no_mangle)]
pub extern "C" fn yukarin_s_forward(
/// # Safety
///
/// - `phoneme_list`はRustの`&[i64; length as usize]`として解釈できなければならない。
/// - `speaker_id`はRustの`&[i64; 1]`として解釈できなければならない。
/// - `output`はRustの`&mut [f32; length as usize]`として解釈できなければならない。
#[unsafe(no_mangle)] // SAFETY: voicevox_core_c_apiを構成するライブラリの中に、これと同名のシンボルは存在しない
pub unsafe extern "C" fn yukarin_s_forward(
length: i64,
phoneme_list: *mut i64,
speaker_id: *mut i64,
output: *mut f32,
) -> bool {
init_logger_once();
assert_aligned(phoneme_list);
assert_aligned(speaker_id);
assert_aligned(output);
let synthesizer = &*lock_synthesizer();
let result = ensure_initialized!(synthesizer).predict_duration(
// SAFETY: The safety contract must be upheld by the caller.
unsafe { std::slice::from_raw_parts_mut(phoneme_list, length as usize) },
StyleId::new(unsafe { *speaker_id as u32 }),
);
match result {
Ok(output_vec) => {
// SAFETY: The safety contract must be upheld by the caller.
let output_slice = unsafe { std::slice::from_raw_parts_mut(output, length as usize) };
output_slice.clone_from_slice(&output_vec);
true
Expand All @@ -248,9 +257,18 @@ pub extern "C" fn yukarin_s_forward(
}
}

// SAFETY: voicevox_core_c_apiを構成するライブラリの中に、これと同名のシンボルは存在しない
#[unsafe(no_mangle)]
pub extern "C" fn yukarin_sa_forward(
/// # Safety
///
/// - `vowel_phoneme_list`はRustの`&[i64; length as usize]`として解釈できなければならない。
/// - `consonant_phoneme_list`はRustの`&[i64; length as usize]`として解釈できなければならない。
/// - `start_accent_list`はRustの`&[i64; length as usize]`として解釈できなければならない。
/// - `end_accent_list`はRustの`&[i64; length as usize]`として解釈できなければならない。
/// - `start_accent_phrase_list`はRustの`&[i64; length as usize]`として解釈できなければならない。
/// - `end_accent_phrase_list`はRustの`&[i64; length as usize]`として解釈できなければならない。
/// - `speaker_id`はRustの`&[i64; 1]`として解釈できなければならない。
/// - `output`はRustの`&mut [f32; length as usize]`として解釈できなければならない。
#[unsafe(no_mangle)] // SAFETY: voicevox_core_c_apiを構成するライブラリの中に、これと同名のシンボルは存在しない
pub unsafe extern "C" fn yukarin_sa_forward(
length: i64,
vowel_phoneme_list: *mut i64,
consonant_phoneme_list: *mut i64,
Expand All @@ -262,9 +280,18 @@ pub extern "C" fn yukarin_sa_forward(
output: *mut f32,
) -> bool {
init_logger_once();
assert_aligned(vowel_phoneme_list);
assert_aligned(consonant_phoneme_list);
assert_aligned(start_accent_list);
assert_aligned(end_accent_list);
assert_aligned(start_accent_phrase_list);
assert_aligned(end_accent_phrase_list);
assert_aligned(speaker_id);
assert_aligned(output);
let synthesizer = &*lock_synthesizer();
let result = ensure_initialized!(synthesizer).predict_intonation(
length as usize,
// SAFETY: The safety contract must be upheld by the caller.
unsafe { std::slice::from_raw_parts(vowel_phoneme_list, length as usize) },
unsafe { std::slice::from_raw_parts(consonant_phoneme_list, length as usize) },
unsafe { std::slice::from_raw_parts(start_accent_list, length as usize) },
Expand All @@ -275,6 +302,7 @@ pub extern "C" fn yukarin_sa_forward(
);
match result {
Ok(output_vec) => {
// SAFETY: The safety contract must be upheld by the caller.
let output_slice = unsafe { std::slice::from_raw_parts_mut(output, length as usize) };
output_slice.clone_from_slice(&output_vec);
true
Expand All @@ -286,9 +314,14 @@ pub extern "C" fn yukarin_sa_forward(
}
}

// SAFETY: voicevox_core_c_apiを構成するライブラリの中に、これと同名のシンボルは存在しない
#[unsafe(no_mangle)]
pub extern "C" fn decode_forward(
/// # Safety
///
/// - `f0`はRustの`&[f32; length as usize]`として解釈できなければならない。
/// - `phoneme`はRustの`&[f32; phoneme_size * length as usize]`として解釈できなければならない。
/// - `speaker_id`はRustの`&[i64; 1]`として解釈できなければならない。
/// - `output`はRustの`&mut [f32; length as usize * 256]`として解釈できなければならない。
#[unsafe(no_mangle)] // SAFETY: voicevox_core_c_apiを構成するライブラリの中に、これと同名のシンボルは存在しない
pub unsafe extern "C" fn decode_forward(
length: i64,
phoneme_size: i64,
f0: *mut f32,
Expand All @@ -297,18 +330,24 @@ pub extern "C" fn decode_forward(
output: *mut f32,
) -> bool {
init_logger_once();
assert_aligned(f0);
assert_aligned(phoneme);
assert_aligned(speaker_id);
assert_aligned(output);
let length = length as usize;
let phoneme_size = phoneme_size as usize;
let synthesizer = &*lock_synthesizer();
let result = ensure_initialized!(synthesizer).decode(
length,
phoneme_size,
// SAFETY: The safety contract must be upheld by the caller.
unsafe { std::slice::from_raw_parts(f0, length) },
unsafe { std::slice::from_raw_parts(phoneme, phoneme_size * length) },
StyleId::new(unsafe { *speaker_id as u32 }),
);
match result {
Ok(output_vec) => {
// SAFETY: The safety contract must be upheld by the caller.
let output_slice = unsafe { std::slice::from_raw_parts_mut(output, length * 256) };
output_slice.clone_from_slice(&output_vec);
true
Expand All @@ -319,3 +358,11 @@ pub extern "C" fn decode_forward(
}
}
}

#[track_caller]
fn assert_aligned(ptr: *mut impl Sized) {
assert!(
ptr.is_aligned(),
"all of the pointers passed to this library **must** be aligned",
);
}
Loading