diff --git a/crates/download/src/main.rs b/crates/download/src/main.rs index 31452719d..10401e8d4 100644 --- a/crates/download/src/main.rs +++ b/crates/download/src/main.rs @@ -1,5 +1,6 @@ use std::{ borrow::Cow, + collections::HashSet, env, future::Future, io::{self, Cursor, Read}, @@ -26,7 +27,7 @@ use once_cell::sync::Lazy; use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; use strum::{Display, IntoStaticStr}; use tokio::task::{JoinError, JoinSet}; -use tracing::info; +use tracing::{info, warn}; use url::Url; use zip::ZipArchive; @@ -48,7 +49,20 @@ static OPEN_JTALK_DIC_URL: Lazy = Lazy::new(|| { #[derive(clap::Parser)] struct Args { - /// ダウンロードするライブラリを最小限にするように指定 + /// ダウンロード対象を限定する + #[arg( + long, + num_args(1..), + value_name("TARGET"), + conflicts_with_all(["exclude", "min"])) + ] + only: Vec, + + /// ダウンロード対象を除外する + #[arg(long, num_args(1..), value_name("TARGET"), conflicts_with("min"))] + exclude: Vec, + + /// `--only core`のエイリアス #[arg(long, conflicts_with("additional_libraries_version"))] min: bool, @@ -65,7 +79,12 @@ struct Args { additional_libraries_version: String, /// ダウンロードするデバイスを指定する(cudaはlinuxのみ) - #[arg(value_enum, long, default_value(<&str>::from(Device::default())))] + #[arg( + value_enum, + long, + default_value(<&str>::from(Device::default())), + required_if_eq("only", "additional-libraries") + )] device: Device, /// ダウンロードするcpuのアーキテクチャを指定する @@ -87,6 +106,14 @@ struct Args { additional_libraries_repo: RepoName, } +#[derive(ValueEnum, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +enum DownloadTarget { + Core, + Models, + AdditionalLibraries, + Dict, +} + #[derive(Default, ValueEnum, Display, IntoStaticStr, Clone, Copy, PartialEq)] #[strum(serialize_all = "kebab-case")] enum Device { @@ -133,8 +160,9 @@ impl Os { } } -#[derive(parse_display::FromStr, Clone)] +#[derive(parse_display::FromStr, parse_display::Display, Clone)] #[from_str(regex = "(?[a-zA-Z0-9_]+)/(?[a-zA-Z0-9_]+)")] +#[display("{owner}/{repo}")] struct RepoName { owner: String, repo: String, @@ -145,6 +173,8 @@ async fn main() -> anyhow::Result<()> { setup_logger(); let Args { + only, + exclude, min, output, version, @@ -156,6 +186,51 @@ async fn main() -> anyhow::Result<()> { additional_libraries_repo, } = Args::parse(); + let targets: HashSet<_> = if !only.is_empty() { + assert!(exclude.is_empty() && !min); + only.into_iter().collect() + } else if !exclude.is_empty() { + assert!(!min); + DownloadTarget::value_variants() + .iter() + .copied() + .filter(|t| !exclude.contains(t)) + .collect() + } else if min { + [DownloadTarget::Core].into() + } else { + DownloadTarget::value_variants().iter().copied().collect() + }; + + if !(targets.contains(&DownloadTarget::Core) || targets.contains(&DownloadTarget::Models)) { + if version != "latest" { + warn!( + "`--version={version}`が指定されていますが、`core`も`models`もダウンロード対象から\ + 除外されています", + ); + } + if core_repo.to_string() != DEFAULT_CORE_REPO { + warn!( + "`--core-repo={core_repo}`が指定されていますが、`core`も`models`もダウンロード対象\ + から除外されています", + ); + } + } + if !targets.contains(&DownloadTarget::AdditionalLibraries) { + if additional_libraries_version != "latest" { + warn!( + "`--additional-libraries-version={additional_libraries_version}`が指定されています\ + が、`additional-libraries-version`はダウンロード対象から除外されています", + ); + } + if additional_libraries_repo.to_string() != DEFAULT_ADDITIONAL_LIBRARIES_REPO { + warn!( + "`--additional-libraries-repo={additional_libraries_repo}`が指定されていますが、\ + `additional-libraries-version`はダウンロード対象から除外されています", + ); + } + } + let octocrab = &octocrab()?; let core = find_gh_asset(octocrab, &core_repo, &version, |tag| { @@ -202,21 +277,23 @@ async fn main() -> anyhow::Result<()> { let mut tasks = JoinSet::new(); - tasks.spawn(download_and_extract_from_gh( - core, - Stripping::FirstDir, - &output, - &progresses, - )?); - - if !min { + if targets.contains(&DownloadTarget::Core) { + tasks.spawn(download_and_extract_from_gh( + core, + Stripping::FirstDir, + &output, + &progresses, + )?); + } + if targets.contains(&DownloadTarget::Models) { tasks.spawn(download_and_extract_from_gh( model, Stripping::FirstDir, &output.join("model"), &progresses, )?); - + } + if targets.contains(&DownloadTarget::AdditionalLibraries) { if let Some(additional_libraries) = additional_libraries { tasks.spawn(download_and_extract_from_gh( additional_libraries, @@ -225,7 +302,8 @@ async fn main() -> anyhow::Result<()> { &progresses, )?); } - + } + if targets.contains(&DownloadTarget::Dict) { tasks.spawn(download_and_extract_from_url( &OPEN_JTALK_DIC_URL, Stripping::None,