Skip to content

Commit

Permalink
ダウンローダーに--only <TARGET>...--exclude <TARGET>...を追加 (#647)
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip authored Oct 15, 2023
1 parent 935468e commit fc757a5
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 14 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ fs-err = { version = "2.9.0", features = ["tokio"] }
itertools = "0.10.5"
once_cell = "1.18.0"
regex = "1.10.0"
rstest = "0.15.0"
serde = { version = "1.0.145", features = ["derive"] }
serde_json = { version = "1.0.85", features = ["preserve_order"] }
strum = { version = "0.24.1", features = ["derive"] }
Expand Down
3 changes: 3 additions & 0 deletions crates/download/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,6 @@ tracing.workspace = true
tracing-subscriber.workspace = true
url = "2.3.0"
zip = "0.6.3"

[dev-dependencies]
rstest.workspace = true
122 changes: 109 additions & 13 deletions crates/download/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
borrow::Cow,
collections::HashSet,
env,
future::Future,
io::{self, Cursor, Read},
Expand All @@ -26,7 +27,7 @@ use once_cell::sync::Lazy;
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
use strum::{Display, IntoStaticStr};
use tokio::task::{JoinError, JoinSet};
use tracing::info;
use tracing::{info, warn};
use url::Url;
use zip::ZipArchive;

Expand All @@ -48,7 +49,20 @@ static OPEN_JTALK_DIC_URL: Lazy<Url> = Lazy::new(|| {

#[derive(clap::Parser)]
struct Args {
/// ダウンロードするライブラリを最小限にするように指定
/// ダウンロード対象を限定する
#[arg(
long,
num_args(1..),
value_name("TARGET"),
conflicts_with_all(["exclude", "min"]))
]
only: Vec<DownloadTarget>,

/// ダウンロード対象を除外する
#[arg(long, num_args(1..), value_name("TARGET"), conflicts_with("min"))]
exclude: Vec<DownloadTarget>,

/// `--only core`のエイリアス
#[arg(long, conflicts_with("additional_libraries_version"))]
min: bool,

Expand Down Expand Up @@ -87,6 +101,14 @@ struct Args {
additional_libraries_repo: RepoName,
}

#[derive(ValueEnum, Clone, Copy, PartialEq, Eq, Hash)]
enum DownloadTarget {
Core,
Models,
AdditionalLibraries,
Dict,
}

#[derive(Default, ValueEnum, Display, IntoStaticStr, Clone, Copy, PartialEq)]
#[strum(serialize_all = "kebab-case")]
enum Device {
Expand Down Expand Up @@ -133,8 +155,9 @@ impl Os {
}
}

#[derive(parse_display::FromStr, Clone)]
#[derive(parse_display::FromStr, parse_display::Display, Clone)]
#[from_str(regex = "(?<owner>[a-zA-Z0-9_]+)/(?<repo>[a-zA-Z0-9_]+)")]
#[display("{owner}/{repo}")]
struct RepoName {
owner: String,
repo: String,
Expand All @@ -145,6 +168,8 @@ async fn main() -> anyhow::Result<()> {
setup_logger();

let Args {
only,
exclude,
min,
output,
version,
Expand All @@ -156,6 +181,57 @@ async fn main() -> anyhow::Result<()> {
additional_libraries_repo,
} = Args::parse();

let targets: HashSet<_> = if !only.is_empty() {
assert!(exclude.is_empty() && !min);
only.into_iter().collect()
} else if !exclude.is_empty() {
assert!(!min);
DownloadTarget::value_variants()
.iter()
.copied()
.filter(|t| !exclude.contains(t))
.collect()
} else if min {
[DownloadTarget::Core].into()
} else {
DownloadTarget::value_variants().iter().copied().collect()
};

if !(targets.contains(&DownloadTarget::Core) || targets.contains(&DownloadTarget::Models)) {
if version != "latest" {
warn!(
"`--version={version}`が指定されていますが、`core`も`models`もダウンロード対象から\
除外されています",
);
}
if core_repo.to_string() != DEFAULT_CORE_REPO {
warn!(
"`--core-repo={core_repo}`が指定されていますが、`core`も`models`もダウンロード対象\
から除外されています",
);
}
}
if !targets.contains(&DownloadTarget::AdditionalLibraries) {
if additional_libraries_version != "latest" {
warn!(
"`--additional-libraries-version={additional_libraries_version}`が指定されています\
が、`additional-libraries-version`はダウンロード対象から除外されています",
);
}
if additional_libraries_repo.to_string() != DEFAULT_ADDITIONAL_LIBRARIES_REPO {
warn!(
"`--additional-libraries-repo={additional_libraries_repo}`が指定されていますが、\
`additional-libraries-version`はダウンロード対象から除外されています",
);
}
if device == Device::Cpu {
warn!(
"`--device`が指定されていない、もしくは`--device=cpu`が指定されていますが、\
`additional-libraries-version`はダウンロード対象から除外されています",
);
}
}

let octocrab = &octocrab()?;

let core = find_gh_asset(octocrab, &core_repo, &version, |tag| {
Expand Down Expand Up @@ -202,21 +278,23 @@ async fn main() -> anyhow::Result<()> {

let mut tasks = JoinSet::new();

tasks.spawn(download_and_extract_from_gh(
core,
Stripping::FirstDir,
&output,
&progresses,
)?);

if !min {
if targets.contains(&DownloadTarget::Core) {
tasks.spawn(download_and_extract_from_gh(
core,
Stripping::FirstDir,
&output,
&progresses,
)?);
}
if targets.contains(&DownloadTarget::Models) {
tasks.spawn(download_and_extract_from_gh(
model,
Stripping::FirstDir,
&output.join("model"),
&progresses,
)?);

}
if targets.contains(&DownloadTarget::AdditionalLibraries) {
if let Some(additional_libraries) = additional_libraries {
tasks.spawn(download_and_extract_from_gh(
additional_libraries,
Expand All @@ -225,7 +303,8 @@ async fn main() -> anyhow::Result<()> {
&progresses,
)?);
}

}
if targets.contains(&DownloadTarget::Dict) {
tasks.spawn(download_and_extract_from_url(
&OPEN_JTALK_DIC_URL,
Stripping::None,
Expand Down Expand Up @@ -584,3 +663,20 @@ enum Stripping {
None,
FirstDir,
}

#[cfg(test)]
mod tests {
use clap::Parser as _;
use rstest::rstest;

use super::Args;

#[rstest]
#[case(&["", "--only", "core", "--exclude", "models"])]
#[case(&["", "--min", "--only", "core"])]
#[case(&["", "--min", "--exclude", "core"])]
fn it_denies_conflicting_options(#[case] args: &[&str]) {
let result = Args::try_parse_from(args).map(|_| ()).map_err(|e| e.kind());
assert_eq!(Err(clap::error::ErrorKind::ArgumentConflict), result);
}
}
2 changes: 1 addition & 1 deletion crates/voicevox_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ rev = "a16714ce16dec76fd0e3041a7acfa484921db3b5"
flate2 = "1.0.24"
heck = "0.4.0"
pretty_assertions = "1.3.0"
rstest = "0.15.0"
rstest.workspace = true
tar = "0.4.38"
test_util.workspace = true

Expand Down

0 comments on commit fc757a5

Please sign in to comment.