From a3705e735d379cd2b10372f31b9cfe3d1b1aff25 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sun, 15 Oct 2023 19:08:06 +0900 Subject: [PATCH 1/6] =?UTF-8?q?=E3=83=80=E3=82=A6=E3=83=B3=E3=83=AD?= =?UTF-8?q?=E3=83=BC=E3=83=80=E3=83=BC=E3=81=AB`--only=20...`?= =?UTF-8?q?=E3=81=A8`--exclude=20...`=E3=82=92=E8=BF=BD=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/download/src/main.rs | 106 +++++++++++++++++++++++++++++++----- 1 file changed, 92 insertions(+), 14 deletions(-) diff --git a/crates/download/src/main.rs b/crates/download/src/main.rs index 31452719d..10401e8d4 100644 --- a/crates/download/src/main.rs +++ b/crates/download/src/main.rs @@ -1,5 +1,6 @@ use std::{ borrow::Cow, + collections::HashSet, env, future::Future, io::{self, Cursor, Read}, @@ -26,7 +27,7 @@ use once_cell::sync::Lazy; use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; use strum::{Display, IntoStaticStr}; use tokio::task::{JoinError, JoinSet}; -use tracing::info; +use tracing::{info, warn}; use url::Url; use zip::ZipArchive; @@ -48,7 +49,20 @@ static OPEN_JTALK_DIC_URL: Lazy = Lazy::new(|| { #[derive(clap::Parser)] struct Args { - /// ダウンロードするライブラリを最小限にするように指定 + /// ダウンロード対象を限定する + #[arg( + long, + num_args(1..), + value_name("TARGET"), + conflicts_with_all(["exclude", "min"])) + ] + only: Vec, + + /// ダウンロード対象を除外する + #[arg(long, num_args(1..), value_name("TARGET"), conflicts_with("min"))] + exclude: Vec, + + /// `--only core`のエイリアス #[arg(long, conflicts_with("additional_libraries_version"))] min: bool, @@ -65,7 +79,12 @@ struct Args { additional_libraries_version: String, /// ダウンロードするデバイスを指定する(cudaはlinuxのみ) - #[arg(value_enum, long, default_value(<&str>::from(Device::default())))] + #[arg( + value_enum, + long, + default_value(<&str>::from(Device::default())), + required_if_eq("only", "additional-libraries") + )] device: Device, /// ダウンロードするcpuのアーキテクチャを指定する @@ -87,6 +106,14 @@ struct Args { additional_libraries_repo: RepoName, } +#[derive(ValueEnum, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +enum DownloadTarget { + Core, + Models, + AdditionalLibraries, + Dict, +} + #[derive(Default, ValueEnum, Display, IntoStaticStr, Clone, Copy, PartialEq)] #[strum(serialize_all = "kebab-case")] enum Device { @@ -133,8 +160,9 @@ impl Os { } } -#[derive(parse_display::FromStr, Clone)] +#[derive(parse_display::FromStr, parse_display::Display, Clone)] #[from_str(regex = "(?[a-zA-Z0-9_]+)/(?[a-zA-Z0-9_]+)")] +#[display("{owner}/{repo}")] struct RepoName { owner: String, repo: String, @@ -145,6 +173,8 @@ async fn main() -> anyhow::Result<()> { setup_logger(); let Args { + only, + exclude, min, output, version, @@ -156,6 +186,51 @@ async fn main() -> anyhow::Result<()> { additional_libraries_repo, } = Args::parse(); + let targets: HashSet<_> = if !only.is_empty() { + assert!(exclude.is_empty() && !min); + only.into_iter().collect() + } else if !exclude.is_empty() { + assert!(!min); + DownloadTarget::value_variants() + .iter() + .copied() + .filter(|t| !exclude.contains(t)) + .collect() + } else if min { + [DownloadTarget::Core].into() + } else { + DownloadTarget::value_variants().iter().copied().collect() + }; + + if !(targets.contains(&DownloadTarget::Core) || targets.contains(&DownloadTarget::Models)) { + if version != "latest" { + warn!( + "`--version={version}`が指定されていますが、`core`も`models`もダウンロード対象から\ + 除外されています", + ); + } + if core_repo.to_string() != DEFAULT_CORE_REPO { + warn!( + "`--core-repo={core_repo}`が指定されていますが、`core`も`models`もダウンロード対象\ + から除外されています", + ); + } + } + if !targets.contains(&DownloadTarget::AdditionalLibraries) { + if additional_libraries_version != "latest" { + warn!( + "`--additional-libraries-version={additional_libraries_version}`が指定されています\ + が、`additional-libraries-version`はダウンロード対象から除外されています", + ); + } + if additional_libraries_repo.to_string() != DEFAULT_ADDITIONAL_LIBRARIES_REPO { + warn!( + "`--additional-libraries-repo={additional_libraries_repo}`が指定されていますが、\ + `additional-libraries-version`はダウンロード対象から除外されています", + ); + } + } + let octocrab = &octocrab()?; let core = find_gh_asset(octocrab, &core_repo, &version, |tag| { @@ -202,21 +277,23 @@ async fn main() -> anyhow::Result<()> { let mut tasks = JoinSet::new(); - tasks.spawn(download_and_extract_from_gh( - core, - Stripping::FirstDir, - &output, - &progresses, - )?); - - if !min { + if targets.contains(&DownloadTarget::Core) { + tasks.spawn(download_and_extract_from_gh( + core, + Stripping::FirstDir, + &output, + &progresses, + )?); + } + if targets.contains(&DownloadTarget::Models) { tasks.spawn(download_and_extract_from_gh( model, Stripping::FirstDir, &output.join("model"), &progresses, )?); - + } + if targets.contains(&DownloadTarget::AdditionalLibraries) { if let Some(additional_libraries) = additional_libraries { tasks.spawn(download_and_extract_from_gh( additional_libraries, @@ -225,7 +302,8 @@ async fn main() -> anyhow::Result<()> { &progresses, )?); } - + } + if targets.contains(&DownloadTarget::Dict) { tasks.spawn(download_and_extract_from_url( &OPEN_JTALK_DIC_URL, Stripping::None, From 7808c1c26a1578fc3c30250ce2740d807a88e5ce Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sun, 15 Oct 2023 22:42:27 +0900 Subject: [PATCH 2/6] =?UTF-8?q?=E4=B8=8D=E8=A6=81=E3=81=AA`derive`?= =?UTF-8?q?=E3=82=92=E7=9C=81=E3=81=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/download/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/download/src/main.rs b/crates/download/src/main.rs index 10401e8d4..6cdd6012f 100644 --- a/crates/download/src/main.rs +++ b/crates/download/src/main.rs @@ -106,7 +106,7 @@ struct Args { additional_libraries_repo: RepoName, } -#[derive(ValueEnum, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(ValueEnum, Clone, Copy, PartialEq, Eq, Hash)] enum DownloadTarget { Core, Models, From 396c036f4b8fcbe91b4a9e6f0380d796a421f415 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sun, 15 Oct 2023 22:45:14 +0900 Subject: [PATCH 3/6] =?UTF-8?q?=E5=8D=98=E4=BD=93=E3=83=86=E3=82=B9?= =?UTF-8?q?=E3=83=88=E3=82=92=E8=BF=BD=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 1 + Cargo.toml | 1 + crates/download/Cargo.toml | 3 +++ crates/download/src/main.rs | 27 +++++++++++++++++++++++++++ crates/voicevox_core/Cargo.toml | 2 +- 5 files changed, 33 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 689b8e992..4bf93c98a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1170,6 +1170,7 @@ dependencies = [ "platforms", "rayon", "reqwest", + "rstest", "strum", "tokio", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 0c8d7aca3..9058d33d1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ fs-err = { version = "2.9.0", features = ["tokio"] } itertools = "0.10.5" once_cell = "1.18.0" regex = "1.10.0" +rstest = "0.15.0" serde = { version = "1.0.145", features = ["derive"] } serde_json = { version = "1.0.85", features = ["preserve_order"] } strum = { version = "0.24.1", features = ["derive"] } diff --git a/crates/download/Cargo.toml b/crates/download/Cargo.toml index c3063a9dc..ad80400ed 100644 --- a/crates/download/Cargo.toml +++ b/crates/download/Cargo.toml @@ -27,3 +27,6 @@ tracing.workspace = true tracing-subscriber.workspace = true url = "2.3.0" zip = "0.6.3" + +[dev-dependencies] +rstest.workspace = true diff --git a/crates/download/src/main.rs b/crates/download/src/main.rs index 6cdd6012f..828629d1b 100644 --- a/crates/download/src/main.rs +++ b/crates/download/src/main.rs @@ -662,3 +662,30 @@ enum Stripping { None, FirstDir, } + +#[cfg(test)] +mod tests { + use clap::Parser as _; + use rstest::rstest; + + use super::Args; + + #[rstest] + #[case(&["", "--only", "core", "--exclude", "models"])] + #[case(&["", "--min", "--only", "core"])] + #[case(&["", "--min", "--exclude", "core"])] + fn it_denies_conflicting_options(#[case] args: &[&str]) { + let result = parse(args); + assert_eq!(Err(clap::error::ErrorKind::ArgumentConflict), result); + } + + #[test] + fn it_denies_only_option_without_device_option() { + let result = parse(&["", "--only", "additional-libraries"]); + assert_eq!(Err(clap::error::ErrorKind::MissingRequiredArgument), result); + } + + fn parse(args: &[&str]) -> Result<(), clap::error::ErrorKind> { + Args::try_parse_from(args).map(|_| ()).map_err(|e| e.kind()) + } +} diff --git a/crates/voicevox_core/Cargo.toml b/crates/voicevox_core/Cargo.toml index 7627eca4c..5e46fe653 100644 --- a/crates/voicevox_core/Cargo.toml +++ b/crates/voicevox_core/Cargo.toml @@ -46,7 +46,7 @@ rev = "a16714ce16dec76fd0e3041a7acfa484921db3b5" flate2 = "1.0.24" heck = "0.4.0" pretty_assertions = "1.3.0" -rstest = "0.15.0" +rstest.workspace = true tar = "0.4.38" test_util.workspace = true From 1682bb2b308cf29c326c759bf3590b5f0aa7d60f Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Mon, 16 Oct 2023 02:47:48 +0900 Subject: [PATCH 4/6] =?UTF-8?q?=E3=83=86=E3=82=B9=E3=83=88=E5=90=8D?= =?UTF-8?q?=E3=82=92=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/download/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/download/src/main.rs b/crates/download/src/main.rs index 828629d1b..9e4ff4e10 100644 --- a/crates/download/src/main.rs +++ b/crates/download/src/main.rs @@ -680,7 +680,7 @@ mod tests { } #[test] - fn it_denies_only_option_without_device_option() { + fn it_denies_only_additional_libraries_option_without_device_option() { let result = parse(&["", "--only", "additional-libraries"]); assert_eq!(Err(clap::error::ErrorKind::MissingRequiredArgument), result); } From 2f941b579507c1fa07852bcdaa4a20badca0480d Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Mon, 16 Oct 2023 03:03:06 +0900 Subject: [PATCH 5/6] =?UTF-8?q?=E3=80=8C`--device`=E6=8A=9C=E3=81=8D?= =?UTF-8?q?=E3=81=AE`--only=20additional-libraries`=E3=80=8D=E3=82=92warni?= =?UTF-8?q?ng=E3=81=AB=E9=99=8D=E6=A0=BC=E3=81=99=E3=82=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/download/src/main.rs | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/crates/download/src/main.rs b/crates/download/src/main.rs index 9e4ff4e10..49b9f67f3 100644 --- a/crates/download/src/main.rs +++ b/crates/download/src/main.rs @@ -79,12 +79,7 @@ struct Args { additional_libraries_version: String, /// ダウンロードするデバイスを指定する(cudaはlinuxのみ) - #[arg( - value_enum, - long, - default_value(<&str>::from(Device::default())), - required_if_eq("only", "additional-libraries") - )] + #[arg(value_enum, long, default_value(<&str>::from(Device::default())))] device: Device, /// ダウンロードするcpuのアーキテクチャを指定する @@ -229,6 +224,12 @@ async fn main() -> anyhow::Result<()> { `additional-libraries-version`はダウンロード対象から除外されています", ); } + if device != Device::Cpu { + warn!( + "`--device={device}`が指定されていますが、`additional-libraries-version`は\ + ダウンロード対象から除外されています", + ); + } } let octocrab = &octocrab()?; @@ -675,17 +676,7 @@ mod tests { #[case(&["", "--min", "--only", "core"])] #[case(&["", "--min", "--exclude", "core"])] fn it_denies_conflicting_options(#[case] args: &[&str]) { - let result = parse(args); + let result = Args::try_parse_from(args).map(|_| ()).map_err(|e| e.kind()); assert_eq!(Err(clap::error::ErrorKind::ArgumentConflict), result); } - - #[test] - fn it_denies_only_additional_libraries_option_without_device_option() { - let result = parse(&["", "--only", "additional-libraries"]); - assert_eq!(Err(clap::error::ErrorKind::MissingRequiredArgument), result); - } - - fn parse(args: &[&str]) -> Result<(), clap::error::ErrorKind> { - Args::try_parse_from(args).map(|_| ()).map_err(|e| e.kind()) - } } From e7c64fc54f900a7d0d38f5cbe37fe1b56a69c156 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Mon, 16 Oct 2023 03:07:01 +0900 Subject: [PATCH 6/6] =?UTF-8?q?=E6=9D=A1=E4=BB=B6=E3=82=92=E4=BF=AE?= =?UTF-8?q?=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/download/src/main.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/download/src/main.rs b/crates/download/src/main.rs index 49b9f67f3..03a426478 100644 --- a/crates/download/src/main.rs +++ b/crates/download/src/main.rs @@ -224,10 +224,10 @@ async fn main() -> anyhow::Result<()> { `additional-libraries-version`はダウンロード対象から除外されています", ); } - if device != Device::Cpu { + if device == Device::Cpu { warn!( - "`--device={device}`が指定されていますが、`additional-libraries-version`は\ - ダウンロード対象から除外されています", + "`--device`が指定されていない、もしくは`--device=cpu`が指定されていますが、\ + `additional-libraries-version`はダウンロード対象から除外されています", ); } }