diff --git a/.gitignore b/.gitignore index eeacc3c..7333653 100755 --- a/.gitignore +++ b/.gitignore @@ -148,3 +148,6 @@ cython_debug/ .python-version build/ dist/ + +*.aivm +*.aivmx diff --git a/aivmlib/__init__.py b/aivmlib/__init__.py index 1487cb5..f6e8624 100755 --- a/aivmlib/__init__.py +++ b/aivmlib/__init__.py @@ -51,7 +51,7 @@ def generate_aivm_metadata( style_vectors_file.seek(0) # Style-Bert-VITS2 系の音声合成モデルの場合 - if model_architecture.startswith('Style-Bert-VITS2'): + if model_architecture in [ModelArchitecture.StyleBertVITS2, ModelArchitecture.StyleBertVITS2JPExtra]: # ハイパーパラメータファイル (JSON) を読み込んだ後、Pydantic でバリデーションする hyper_parameters_content = hyper_parameters_file.read().decode('utf-8') @@ -147,7 +147,7 @@ def validate_aivm_metadata(raw_metadata: dict[str, str]) -> AivmMetadata: # ハイパーパラメータのバリデーション if 'aivm_hyper_parameters' in raw_metadata: try: - if aivm_manifest.model_architecture.startswith('Style-Bert-VITS2'): + if aivm_manifest.model_architecture in [ModelArchitecture.StyleBertVITS2, ModelArchitecture.StyleBertVITS2JPExtra]: aivm_hyper_parameters = StyleBertVITS2HyperParameters.model_validate_json(raw_metadata['aivm_hyper_parameters']) else: raise AivmValidationError(f"Unsupported hyper-parameters for model architecture: {aivm_manifest.model_architecture}.") @@ -395,7 +395,7 @@ def apply_aivm_manifest_to_hyper_parameters(aivm_metadata: AivmMetadata) -> None """ # Style-Bert-VITS2 系の音声合成モデルの場合 - if aivm_metadata.manifest.model_architecture.startswith('Style-Bert-VITS2'): + if aivm_metadata.manifest.model_architecture in [ModelArchitecture.StyleBertVITS2, ModelArchitecture.StyleBertVITS2JPExtra]: # スタイルベクトルが設定されていなければエラー if aivm_metadata.style_vectors is None: diff --git a/aivmlib/__main__.py b/aivmlib/__main__.py index e788df0..e321062 100755 --- a/aivmlib/__main__.py +++ b/aivmlib/__main__.py @@ -1,18 +1,13 @@ import rich +import traceback import typer from pathlib import Path from rich.rule import Rule from rich.style import Style from typing import Annotated, Union -from aivmlib import ( - generate_aivm_metadata, - read_aivm_metadata, - read_aivmx_metadata, - write_aivm_metadata, - write_aivmx_metadata, -) +import aivmlib from aivmlib.schemas.aivm_manifest import ModelArchitecture @@ -24,38 +19,40 @@ def show_metadata( file_path: Annotated[Path, typer.Argument(help='Path to the AIVM / AIVMX file')] ): """ - 指定されたパスの AIVM / AIVMX ファイル内に記録されている AIVM メタデータを見やすく出力する + 指定されたパスの AIVM / AIVMX ファイル内に格納されている AIVM メタデータを見やすく出力する """ try: with file_path.open('rb') as file: if file_path.suffix == '.aivmx': - metadata = read_aivmx_metadata(file) + metadata = aivmlib.read_aivmx_metadata(file) else: - metadata = read_aivm_metadata(file) + metadata = aivmlib.read_aivm_metadata(file) for speaker in metadata.manifest.speakers: speaker.icon = '(Image Base64 DataURL)' for style in speaker.styles: - style.icon = '(Image Base64 DataURL)' + if style.icon: + style.icon = '(Image Base64 DataURL)' for sample in style.voice_samples: sample.audio = '(Audio Base64 DataURL)' - rich.print(Rule(title='AIVM Manifest:', characters='=', style=Style(color='#E33157'))) + rich.print(Rule(title='AIVM Manifest:', characters='=', style=Style(color='#41A2EC'))) rich.print(metadata.manifest) - rich.print(Rule(title='Hyper Parameters:', characters='=', style=Style(color='#E33157'))) + rich.print(Rule(title='Hyper Parameters:', characters='=', style=Style(color='#41A2EC'))) rich.print(metadata.hyper_parameters) - rich.print(Rule(characters='=', style=Style(color='#E33157'))) + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) except Exception as e: - rich.print(Rule(characters='=', style=Style(color='#E33157'))) + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) rich.print(f'[red]Error reading AIVM or AIVMX file: {e}[/red]') - rich.print(Rule(characters='=', style=Style(color='#E33157'))) + rich.print(traceback.format_exc()) + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) @app.command() def create_aivm( output_path: Annotated[Path, typer.Option('-o', '--output', help='Path to the output AIVM file')], safetensors_model_path: Annotated[Path, typer.Option('-m', '--model', help='Path to the Safetensors model file')], - hyper_parameters_path: Annotated[Path, typer.Option('-h', '--hyper-parameters', help='Path to the hyper parameters file')], + hyper_parameters_path: Annotated[Union[Path, None], typer.Option('-h', '--hyper-parameters', help='Path to the hyper parameters file (optional)')] = None, style_vectors_path: Annotated[Union[Path, None], typer.Option('-s', '--style-vectors', help='Path to the style vectors file (optional)')] = None, model_architecture: Annotated[ModelArchitecture, typer.Option('-a', '--model-architecture', help='Model architecture')] = ModelArchitecture.StyleBertVITS2JPExtra, ): @@ -64,31 +61,68 @@ def create_aivm( それを書き込んだ仮の AIVM ファイルを生成する """ + # 拡張子チェック + if safetensors_model_path.suffix != '.safetensors': + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + rich.print('[red]Safetensors model file must have a .safetensors extension.[/red]') + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + return + if output_path.suffix != '.aivm': + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + rich.print('[red]Output file must have a .aivm extension.[/red]') + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + return + try: + # アーキテクチャに合わせて未指定のファイルパスを自動設定 + if model_architecture in [ModelArchitecture.StyleBertVITS2, ModelArchitecture.StyleBertVITS2JPExtra]: + model_dir = safetensors_model_path.parent + if not hyper_parameters_path: + hyper_parameters_path = model_dir / 'config.json' + if not style_vectors_path: + style_vectors_path = model_dir / 'style_vectors.npy' + + # 必要なファイルが存在しない場合はエラーを発生させる + if not hyper_parameters_path.exists(): + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + rich.print(f'[red]Hyper parameters file not found: {hyper_parameters_path}[/red]') + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + return + if not style_vectors_path.exists(): + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + rich.print(f'[red]Style vectors file not found: {style_vectors_path}[/red]') + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + return + else: + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + rich.print(f'[red]Model architecture {model_architecture} is not supported.[/red]') + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + return + with hyper_parameters_path.open('rb') as hyper_parameters_file: - style_vectors_file = style_vectors_path.open('rb') if style_vectors_path else None - metadata = generate_aivm_metadata(model_architecture, hyper_parameters_file, style_vectors_file) - if style_vectors_file: - style_vectors_file.close() - - with safetensors_model_path.open('rb') as safetensors_file: - new_aivm_file_content = write_aivm_metadata(safetensors_file, metadata) - with output_path.open('wb') as f: - f.write(new_aivm_file_content) - rich.print(Rule(characters='=', style=Style(color='#E33157'))) - rich.print(f'Generated AIVM file: {output_path}') - rich.print(Rule(characters='=', style=Style(color='#E33157'))) + with style_vectors_path.open('rb') as style_vectors_file: + metadata = aivmlib.generate_aivm_metadata(model_architecture, hyper_parameters_file, style_vectors_file) + + with safetensors_model_path.open('rb') as safetensors_file: + new_aivm_file_content = aivmlib.write_aivm_metadata(safetensors_file, metadata) + with output_path.open('wb') as f: + f.write(new_aivm_file_content) + + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + rich.print(f'Generated AIVM file: {output_path}') + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) except Exception as e: - rich.print(Rule(characters='=', style=Style(color='#E33157'))) + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) rich.print(f'[red]Error creating AIVM file: {e}[/red]') - rich.print(Rule(characters='=', style=Style(color='#E33157'))) + rich.print(traceback.format_exc()) + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) @app.command() def create_aivmx( output_path: Annotated[Path, typer.Option('-o', '--output', help='Path to the output AIVMX file')], onnx_model_path: Annotated[Path, typer.Option('-m', '--model', help='Path to the ONNX model file')], - hyper_parameters_path: Annotated[Path, typer.Option('-h', '--hyper-parameters', help='Path to the hyper parameters file')], + hyper_parameters_path: Annotated[Union[Path, None], typer.Option('-h', '--hyper-parameters', help='Path to the hyper parameters file (optional)')] = None, style_vectors_path: Annotated[Union[Path, None], typer.Option('-s', '--style-vectors', help='Path to the style vectors file (optional)')] = None, model_architecture: Annotated[ModelArchitecture, typer.Option('-a', '--model-architecture', help='Model architecture')] = ModelArchitecture.StyleBertVITS2JPExtra, ): @@ -97,24 +131,61 @@ def create_aivmx( それを書き込んだ仮の AIVMX ファイルを生成する """ + # 拡張子チェック + if onnx_model_path.suffix != '.onnx': + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + rich.print('[red]ONNX model file must have a .onnx extension.[/red]') + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + return + if output_path.suffix != '.aivmx': + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + rich.print('[red]Output file must have a .aivmx extension.[/red]') + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + return + try: + # アーキテクチャに合わせて未指定のファイルパスを自動設定 + if model_architecture in [ModelArchitecture.StyleBertVITS2, ModelArchitecture.StyleBertVITS2JPExtra]: + model_dir = onnx_model_path.parent + if not hyper_parameters_path: + hyper_parameters_path = model_dir / 'config.json' + if not style_vectors_path: + style_vectors_path = model_dir / 'style_vectors.npy' + + # 必要なファイルが存在しない場合はエラーを発生させる + if not hyper_parameters_path.exists(): + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + rich.print(f'[red]Hyper parameters file not found: {hyper_parameters_path}[/red]') + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + return + if not style_vectors_path.exists(): + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + rich.print(f'[red]Style vectors file not found: {style_vectors_path}[/red]') + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + return + else: + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + rich.print(f'[red]Model architecture {model_architecture} is not supported.[/red]') + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + return + with hyper_parameters_path.open('rb') as hyper_parameters_file: - style_vectors_file = style_vectors_path.open('rb') if style_vectors_path else None - metadata = generate_aivm_metadata(model_architecture, hyper_parameters_file, style_vectors_file) - if style_vectors_file: - style_vectors_file.close() - - with onnx_model_path.open('rb') as onnx_file: - new_aivmx_file_content = write_aivmx_metadata(onnx_file, metadata) - with output_path.open('wb') as f: - f.write(new_aivmx_file_content) - rich.print(Rule(characters='=', style=Style(color='#E33157'))) - rich.print(f'Generated AIVMX file: {output_path}') - rich.print(Rule(characters='=', style=Style(color='#E33157'))) + with style_vectors_path.open('rb') as style_vectors_file: + metadata = aivmlib.generate_aivm_metadata(model_architecture, hyper_parameters_file, style_vectors_file) + + with onnx_model_path.open('rb') as onnx_file: + new_aivmx_file_content = aivmlib.write_aivmx_metadata(onnx_file, metadata) + with output_path.open('wb') as f: + f.write(new_aivmx_file_content) + + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) + rich.print(f'Generated AIVMX file: {output_path}') + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) except Exception as e: - rich.print(Rule(characters='=', style=Style(color='#E33157'))) + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) rich.print(f'[red]Error creating AIVMX file: {e}[/red]') - rich.print(Rule(characters='=', style=Style(color='#E33157'))) + rich.print(traceback.format_exc()) + rich.print(Rule(characters='=', style=Style(color='#41A2EC'))) if __name__ == '__main__':