diff --git a/Cargo.lock b/Cargo.lock index 689b8e992..4bf93c98a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1170,6 +1170,7 @@ dependencies = [ "platforms", "rayon", "reqwest", + "rstest", "strum", "tokio", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 0c8d7aca3..9058d33d1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ fs-err = { version = "2.9.0", features = ["tokio"] } itertools = "0.10.5" once_cell = "1.18.0" regex = "1.10.0" +rstest = "0.15.0" serde = { version = "1.0.145", features = ["derive"] } serde_json = { version = "1.0.85", features = ["preserve_order"] } strum = { version = "0.24.1", features = ["derive"] } diff --git a/crates/download/Cargo.toml b/crates/download/Cargo.toml index c3063a9dc..ad80400ed 100644 --- a/crates/download/Cargo.toml +++ b/crates/download/Cargo.toml @@ -27,3 +27,6 @@ tracing.workspace = true tracing-subscriber.workspace = true url = "2.3.0" zip = "0.6.3" + +[dev-dependencies] +rstest.workspace = true diff --git a/crates/download/src/main.rs b/crates/download/src/main.rs index 31452719d..03a426478 100644 --- a/crates/download/src/main.rs +++ b/crates/download/src/main.rs @@ -1,5 +1,6 @@ use std::{ borrow::Cow, + collections::HashSet, env, future::Future, io::{self, Cursor, Read}, @@ -26,7 +27,7 @@ use once_cell::sync::Lazy; use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; use strum::{Display, IntoStaticStr}; use tokio::task::{JoinError, JoinSet}; -use tracing::info; +use tracing::{info, warn}; use url::Url; use zip::ZipArchive; @@ -48,7 +49,20 @@ static OPEN_JTALK_DIC_URL: Lazy = Lazy::new(|| { #[derive(clap::Parser)] struct Args { - /// ダウンロードするライブラリを最小限にするように指定 + /// ダウンロード対象を限定する + #[arg( + long, + num_args(1..), + value_name("TARGET"), + conflicts_with_all(["exclude", "min"])) + ] + only: Vec, + + /// ダウンロード対象を除外する + #[arg(long, num_args(1..), value_name("TARGET"), conflicts_with("min"))] + exclude: Vec, + + /// `--only core`のエイリアス #[arg(long, conflicts_with("additional_libraries_version"))] min: bool, @@ -87,6 +101,14 @@ struct Args { additional_libraries_repo: RepoName, } +#[derive(ValueEnum, Clone, Copy, PartialEq, Eq, Hash)] +enum DownloadTarget { + Core, + Models, + AdditionalLibraries, + Dict, +} + #[derive(Default, ValueEnum, Display, IntoStaticStr, Clone, Copy, PartialEq)] #[strum(serialize_all = "kebab-case")] enum Device { @@ -133,8 +155,9 @@ impl Os { } } -#[derive(parse_display::FromStr, Clone)] +#[derive(parse_display::FromStr, parse_display::Display, Clone)] #[from_str(regex = "(?[a-zA-Z0-9_]+)/(?[a-zA-Z0-9_]+)")] +#[display("{owner}/{repo}")] struct RepoName { owner: String, repo: String, @@ -145,6 +168,8 @@ async fn main() -> anyhow::Result<()> { setup_logger(); let Args { + only, + exclude, min, output, version, @@ -156,6 +181,57 @@ async fn main() -> anyhow::Result<()> { additional_libraries_repo, } = Args::parse(); + let targets: HashSet<_> = if !only.is_empty() { + assert!(exclude.is_empty() && !min); + only.into_iter().collect() + } else if !exclude.is_empty() { + assert!(!min); + DownloadTarget::value_variants() + .iter() + .copied() + .filter(|t| !exclude.contains(t)) + .collect() + } else if min { + [DownloadTarget::Core].into() + } else { + DownloadTarget::value_variants().iter().copied().collect() + }; + + if !(targets.contains(&DownloadTarget::Core) || targets.contains(&DownloadTarget::Models)) { + if version != "latest" { + warn!( + "`--version={version}`が指定されていますが、`core`も`models`もダウンロード対象から\ + 除外されています", + ); + } + if core_repo.to_string() != DEFAULT_CORE_REPO { + warn!( + "`--core-repo={core_repo}`が指定されていますが、`core`も`models`もダウンロード対象\ + から除外されています", + ); + } + } + if !targets.contains(&DownloadTarget::AdditionalLibraries) { + if additional_libraries_version != "latest" { + warn!( + "`--additional-libraries-version={additional_libraries_version}`が指定されています\ + が、`additional-libraries-version`はダウンロード対象から除外されています", + ); + } + if additional_libraries_repo.to_string() != DEFAULT_ADDITIONAL_LIBRARIES_REPO { + warn!( + "`--additional-libraries-repo={additional_libraries_repo}`が指定されていますが、\ + `additional-libraries-version`はダウンロード対象から除外されています", + ); + } + if device == Device::Cpu { + warn!( + "`--device`が指定されていない、もしくは`--device=cpu`が指定されていますが、\ + `additional-libraries-version`はダウンロード対象から除外されています", + ); + } + } + let octocrab = &octocrab()?; let core = find_gh_asset(octocrab, &core_repo, &version, |tag| { @@ -202,21 +278,23 @@ async fn main() -> anyhow::Result<()> { let mut tasks = JoinSet::new(); - tasks.spawn(download_and_extract_from_gh( - core, - Stripping::FirstDir, - &output, - &progresses, - )?); - - if !min { + if targets.contains(&DownloadTarget::Core) { + tasks.spawn(download_and_extract_from_gh( + core, + Stripping::FirstDir, + &output, + &progresses, + )?); + } + if targets.contains(&DownloadTarget::Models) { tasks.spawn(download_and_extract_from_gh( model, Stripping::FirstDir, &output.join("model"), &progresses, )?); - + } + if targets.contains(&DownloadTarget::AdditionalLibraries) { if let Some(additional_libraries) = additional_libraries { tasks.spawn(download_and_extract_from_gh( additional_libraries, @@ -225,7 +303,8 @@ async fn main() -> anyhow::Result<()> { &progresses, )?); } - + } + if targets.contains(&DownloadTarget::Dict) { tasks.spawn(download_and_extract_from_url( &OPEN_JTALK_DIC_URL, Stripping::None, @@ -584,3 +663,20 @@ enum Stripping { None, FirstDir, } + +#[cfg(test)] +mod tests { + use clap::Parser as _; + use rstest::rstest; + + use super::Args; + + #[rstest] + #[case(&["", "--only", "core", "--exclude", "models"])] + #[case(&["", "--min", "--only", "core"])] + #[case(&["", "--min", "--exclude", "core"])] + fn it_denies_conflicting_options(#[case] args: &[&str]) { + let result = Args::try_parse_from(args).map(|_| ()).map_err(|e| e.kind()); + assert_eq!(Err(clap::error::ErrorKind::ArgumentConflict), result); + } +} diff --git a/crates/voicevox_core/Cargo.toml b/crates/voicevox_core/Cargo.toml index 7627eca4c..5e46fe653 100644 --- a/crates/voicevox_core/Cargo.toml +++ b/crates/voicevox_core/Cargo.toml @@ -46,7 +46,7 @@ rev = "a16714ce16dec76fd0e3041a7acfa484921db3b5" flate2 = "1.0.24" heck = "0.4.0" pretty_assertions = "1.3.0" -rstest = "0.15.0" +rstest.workspace = true tar = "0.4.38" test_util.workspace = true