Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ダウンローダーに--only <TARGET>...--exclude <TARGET>...を追加 #647

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ fs-err = { version = "2.9.0", features = ["tokio"] }
itertools = "0.10.5"
once_cell = "1.18.0"
regex = "1.10.0"
rstest = "0.15.0"
serde = { version = "1.0.145", features = ["derive"] }
serde_json = { version = "1.0.85", features = ["preserve_order"] }
strum = { version = "0.24.1", features = ["derive"] }
Expand Down
3 changes: 3 additions & 0 deletions crates/download/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,6 @@ tracing.workspace = true
tracing-subscriber.workspace = true
url = "2.3.0"
zip = "0.6.3"

[dev-dependencies]
rstest.workspace = true
133 changes: 119 additions & 14 deletions crates/download/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
borrow::Cow,
collections::HashSet,
env,
future::Future,
io::{self, Cursor, Read},
Expand All @@ -26,7 +27,7 @@ use once_cell::sync::Lazy;
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
use strum::{Display, IntoStaticStr};
use tokio::task::{JoinError, JoinSet};
use tracing::info;
use tracing::{info, warn};
use url::Url;
use zip::ZipArchive;

Expand All @@ -48,7 +49,20 @@ static OPEN_JTALK_DIC_URL: Lazy<Url> = Lazy::new(|| {

#[derive(clap::Parser)]
struct Args {
/// ダウンロードするライブラリを最小限にするように指定
/// ダウンロード対象を限定する
#[arg(
long,
num_args(1..),
value_name("TARGET"),
conflicts_with_all(["exclude", "min"]))
]
only: Vec<DownloadTarget>,

/// ダウンロード対象を除外する
#[arg(long, num_args(1..), value_name("TARGET"), conflicts_with("min"))]
exclude: Vec<DownloadTarget>,

/// `--only core`のエイリアス
#[arg(long, conflicts_with("additional_libraries_version"))]
min: bool,
Hiroshiba marked this conversation as resolved.
Show resolved Hide resolved

Expand All @@ -65,7 +79,12 @@ struct Args {
additional_libraries_version: String,

/// ダウンロードするデバイスを指定する(cudaはlinuxのみ)
#[arg(value_enum, long, default_value(<&str>::from(Device::default())))]
#[arg(
value_enum,
long,
default_value(<&str>::from(Device::default())),
required_if_eq("only", "additional-libraries")
Hiroshiba marked this conversation as resolved.
Show resolved Hide resolved
)]
device: Device,

/// ダウンロードするcpuのアーキテクチャを指定する
Expand All @@ -87,6 +106,14 @@ struct Args {
additional_libraries_repo: RepoName,
}

#[derive(ValueEnum, Clone, Copy, PartialEq, Eq, Hash)]
enum DownloadTarget {
Core,
Models,
AdditionalLibraries,
Dict,
}

#[derive(Default, ValueEnum, Display, IntoStaticStr, Clone, Copy, PartialEq)]
#[strum(serialize_all = "kebab-case")]
enum Device {
Expand Down Expand Up @@ -133,8 +160,9 @@ impl Os {
}
}

#[derive(parse_display::FromStr, Clone)]
#[derive(parse_display::FromStr, parse_display::Display, Clone)]
#[from_str(regex = "(?<owner>[a-zA-Z0-9_]+)/(?<repo>[a-zA-Z0-9_]+)")]
#[display("{owner}/{repo}")]
struct RepoName {
owner: String,
repo: String,
Expand All @@ -145,6 +173,8 @@ async fn main() -> anyhow::Result<()> {
setup_logger();

let Args {
only,
exclude,
min,
output,
version,
Expand All @@ -156,6 +186,51 @@ async fn main() -> anyhow::Result<()> {
additional_libraries_repo,
} = Args::parse();

let targets: HashSet<_> = if !only.is_empty() {
assert!(exclude.is_empty() && !min);
only.into_iter().collect()
} else if !exclude.is_empty() {
assert!(!min);
DownloadTarget::value_variants()
.iter()
.copied()
.filter(|t| !exclude.contains(t))
.collect()
} else if min {
[DownloadTarget::Core].into()
} else {
DownloadTarget::value_variants().iter().copied().collect()
};

if !(targets.contains(&DownloadTarget::Core) || targets.contains(&DownloadTarget::Models)) {
if version != "latest" {
warn!(
"`--version={version}`が指定されていますが、`core`も`models`もダウンロード対象から\
除外されています",
);
}
if core_repo.to_string() != DEFAULT_CORE_REPO {
warn!(
"`--core-repo={core_repo}`が指定されていますが、`core`も`models`もダウンロード対象\
から除外されています",
);
}
}
if !targets.contains(&DownloadTarget::AdditionalLibraries) {
if additional_libraries_version != "latest" {
warn!(
"`--additional-libraries-version={additional_libraries_version}`が指定されています\
が、`additional-libraries-version`はダウンロード対象から除外されています",
);
}
if additional_libraries_repo.to_string() != DEFAULT_ADDITIONAL_LIBRARIES_REPO {
warn!(
"`--additional-libraries-repo={additional_libraries_repo}`が指定されていますが、\
`additional-libraries-version`はダウンロード対象から除外されています",
);
}
}

let octocrab = &octocrab()?;

let core = find_gh_asset(octocrab, &core_repo, &version, |tag| {
Expand Down Expand Up @@ -202,21 +277,23 @@ async fn main() -> anyhow::Result<()> {

let mut tasks = JoinSet::new();

tasks.spawn(download_and_extract_from_gh(
core,
Stripping::FirstDir,
&output,
&progresses,
)?);

if !min {
if targets.contains(&DownloadTarget::Core) {
tasks.spawn(download_and_extract_from_gh(
core,
Stripping::FirstDir,
&output,
&progresses,
)?);
}
if targets.contains(&DownloadTarget::Models) {
tasks.spawn(download_and_extract_from_gh(
model,
Stripping::FirstDir,
&output.join("model"),
&progresses,
)?);

}
if targets.contains(&DownloadTarget::AdditionalLibraries) {
if let Some(additional_libraries) = additional_libraries {
tasks.spawn(download_and_extract_from_gh(
additional_libraries,
Expand All @@ -225,7 +302,8 @@ async fn main() -> anyhow::Result<()> {
&progresses,
)?);
}

}
if targets.contains(&DownloadTarget::Dict) {
tasks.spawn(download_and_extract_from_url(
&OPEN_JTALK_DIC_URL,
Stripping::None,
Expand Down Expand Up @@ -584,3 +662,30 @@ enum Stripping {
None,
FirstDir,
}

#[cfg(test)]
mod tests {
use clap::Parser as _;
use rstest::rstest;

use super::Args;

#[rstest]
#[case(&["", "--only", "core", "--exclude", "models"])]
#[case(&["", "--min", "--only", "core"])]
#[case(&["", "--min", "--exclude", "core"])]
fn it_denies_conflicting_options(#[case] args: &[&str]) {
let result = parse(args);
assert_eq!(Err(clap::error::ErrorKind::ArgumentConflict), result);
}

#[test]
fn it_denies_only_option_without_device_option() {
let result = parse(&["", "--only", "additional-libraries"]);
assert_eq!(Err(clap::error::ErrorKind::MissingRequiredArgument), result);
}

fn parse(args: &[&str]) -> Result<(), clap::error::ErrorKind> {
Args::try_parse_from(args).map(|_| ()).map_err(|e| e.kind())
}
}
2 changes: 1 addition & 1 deletion crates/voicevox_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ rev = "a16714ce16dec76fd0e3041a7acfa484921db3b5"
flate2 = "1.0.24"
heck = "0.4.0"
pretty_assertions = "1.3.0"
rstest = "0.15.0"
rstest.workspace = true
tar = "0.4.38"
test_util.workspace = true

Expand Down
Loading