From 9cfe4e6bf719a5d2a615a4769617e895a4e75a35 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Sat, 20 Apr 2024 15:08:05 +1200 Subject: [PATCH 1/3] Initial StableDiffusion Video model converter --- .../stable_diffusion_video/.gitignore | 3 + .../stable_diffusion_video/README.md | 20 ++ .../stable_diffusion_video/config.py | 8 + .../stable_diffusion_video/config_unet.json | 85 +++++++ .../stable_diffusion_video/convert.py | 211 ++++++++++++++++++ .../stable_diffusion_video/models.py | 49 ++++ .../stable_diffusion_video/requirements.txt | 9 + .../stable_diffusion_video/sd_utils/ort.py | 117 ++++++++++ 8 files changed, 502 insertions(+) create mode 100644 OnnxStack.Converter/stable_diffusion_video/.gitignore create mode 100644 OnnxStack.Converter/stable_diffusion_video/README.md create mode 100644 OnnxStack.Converter/stable_diffusion_video/config.py create mode 100644 OnnxStack.Converter/stable_diffusion_video/config_unet.json create mode 100644 OnnxStack.Converter/stable_diffusion_video/convert.py create mode 100644 OnnxStack.Converter/stable_diffusion_video/models.py create mode 100644 OnnxStack.Converter/stable_diffusion_video/requirements.txt create mode 100644 OnnxStack.Converter/stable_diffusion_video/sd_utils/ort.py diff --git a/OnnxStack.Converter/stable_diffusion_video/.gitignore b/OnnxStack.Converter/stable_diffusion_video/.gitignore new file mode 100644 index 0000000..4cf6f30 --- /dev/null +++ b/OnnxStack.Converter/stable_diffusion_video/.gitignore @@ -0,0 +1,3 @@ +/footprints/ +/cache/ +/result_*.png diff --git a/OnnxStack.Converter/stable_diffusion_video/README.md b/OnnxStack.Converter/stable_diffusion_video/README.md new file mode 100644 index 0000000..893782e --- /dev/null +++ b/OnnxStack.Converter/stable_diffusion_video/README.md @@ -0,0 +1,20 @@ +# OnnxStack.Converter + +## Requirements +```bash +pip install onnxruntime-directml +pip install olive-ai[directml] +python -m pip install -r requirements.txt +``` + +## Usage +```bash +convert.py --optimize --model_input '..\stable-video-diffusion-img2vid-xt' --model_output '..\converted' +``` +`--optimize` - Run the model optimization + +`--model_input` - Safetensor model to convert + +`--model_output` - Output for converted ONNX model (NOTE: This folder is deleted before each run) + +`--image_encoder` - Convert the optional image encoder diff --git a/OnnxStack.Converter/stable_diffusion_video/config.py b/OnnxStack.Converter/stable_diffusion_video/config.py new file mode 100644 index 0000000..7b1b47e --- /dev/null +++ b/OnnxStack.Converter/stable_diffusion_video/config.py @@ -0,0 +1,8 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +vae_sample_size = 512 +unet_sample_size = 24 +cross_attention_dim = 1280 \ No newline at end of file diff --git a/OnnxStack.Converter/stable_diffusion_video/config_unet.json b/OnnxStack.Converter/stable_diffusion_video/config_unet.json new file mode 100644 index 0000000..8256a88 --- /dev/null +++ b/OnnxStack.Converter/stable_diffusion_video/config_unet.json @@ -0,0 +1,85 @@ +{ + "input_model": { + "type": "PyTorchModel", + "config": { + "model_path": "stabilityai/stable-video-diffusion-img2vid-xt", + "model_loader": "unet_load", + "model_script": "models.py", + "io_config": { + "input_names": [ "sample", "timestep", "encoder_hidden_states", "added_time_ids" ], + "output_names": [ "out_sample" ], + "dynamic_axes": { + "sample": {"0": "batch", "1": "frames", "2": "channel", "3": "height", "4": "width"}, + "timestep": {"0": "timestep"}, + "encoder_hidden_states": {"0": "batch", "1": "sequence_length", "2": "cross_attention_dim"}, + "added_time_ids": {"0": "batch", "1": "num_additional_ids" } + } + }, + "dummy_inputs_func": "unet_conversion_inputs" + } + }, + "systems": { + "local_system": { + "type": "LocalSystem", + "config": { + "accelerators": [ + { + "device": "gpu", + "execution_providers": [ + "DmlExecutionProvider" + ] + } + ] + } + } + }, + "evaluators": { + "common_evaluator": { + "metrics": [ + { + "name": "latency", + "type": "latency", + "sub_types": [{"name": "avg"}], + "user_config": { + "user_script": "models.py", + "dataloader_func": "unet_data_loader", + "batch_size": 2 + } + } + ] + } + }, + "passes": { + "convert": { + "type": "OnnxConversion", + "config": { + "target_opset": 16, + "save_as_external_data": true, + "all_tensors_to_one_file": true + } + }, + "optimize": { + "type": "OrtTransformersOptimization", + "config": { + "model_type": "unet", + "opt_level": 0, + "float16": true, + "use_gpu": true, + "keep_io_types": true + } + } + }, + "pass_flows": [ + ["convert", "optimize"] + ], + "engine": { + "log_severity_level": 0, + "evaluator": "common_evaluator", + "evaluate_input_model": false, + "host": "local_system", + "target": "local_system", + "cache_dir": "cache", + "output_name": "unet", + "output_dir": "footprints" + } +} diff --git a/OnnxStack.Converter/stable_diffusion_video/convert.py b/OnnxStack.Converter/stable_diffusion_video/convert.py new file mode 100644 index 0000000..fe7af38 --- /dev/null +++ b/OnnxStack.Converter/stable_diffusion_video/convert.py @@ -0,0 +1,211 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import argparse +import json +import shutil +import sys +import warnings +from pathlib import Path +from typing import Dict + +import config +import torch +from diffusers import DiffusionPipeline +from packaging import version + +from olive.common.utils import set_tempdir +from olive.workflows import run as olive_run + + +# pylint: disable=redefined-outer-name +# ruff: noqa: TID252, T201 + + +def save_image(result, batch_size, provider, num_images, images_saved, image_callback=None): + passed_safety_checker = 0 + for image_index in range(batch_size): + if result.nsfw_content_detected is None or not result.nsfw_content_detected[image_index]: + passed_safety_checker += 1 + if images_saved < num_images: + output_path = f"result_{images_saved}.png" + result.images[image_index].save(output_path) + if image_callback: + image_callback(images_saved, output_path) + images_saved += 1 + print(f"Generated {output_path}") + print(f"Inference Batch End ({passed_safety_checker}/{batch_size} images).") + print("Images passed the safety checker.") + return images_saved + + +def run_inference_loop( + pipeline, + prompt, + num_images, + batch_size, + image_size, + num_inference_steps, + guidance_scale, + strength: float, + provider: str, + image_callback=None, + step_callback=None, +): + images_saved = 0 + + def update_steps(step, timestep, latents): + if step_callback: + step_callback((images_saved // batch_size) * num_inference_steps + step) + + while images_saved < num_images: + print(f"\nInference Batch Start (batch size = {batch_size}).") + + kwargs = {} + + result = pipeline( + [prompt] * batch_size, + num_inference_steps=num_inference_steps, + callback=update_steps if step_callback else None, + height=image_size, + width=image_size, + guidance_scale=guidance_scale, + **kwargs, + ) + + images_saved = save_image(result, batch_size, provider, num_images, images_saved, image_callback) + + +def update_config_with_provider(config: Dict, provider: str): + if provider == "dml": + # DirectML EP is the default, so no need to update config. + return config + elif provider == "cuda": + from sd_utils.ort import update_cuda_config + + return update_cuda_config(config) + else: + raise ValueError(f"Unsupported provider: {provider}") + + +def optimize( + model_input: str, + model_output: Path, + provider: str, + image_encoder: bool +): + from google.protobuf import __version__ as protobuf_version + + # protobuf 4.x aborts with OOM when optimizing unet + if version.parse(protobuf_version) > version.parse("3.20.3"): + print("This script requires protobuf 3.20.3. Please ensure your package version matches requirements.txt.") + sys.exit(1) + + model_dir = model_input + script_dir = Path(__file__).resolve().parent + + # Clean up previously optimized models, if any. + shutil.rmtree(script_dir / "footprints", ignore_errors=True) + shutil.rmtree(model_output, ignore_errors=True) + + # Load the entire PyTorch pipeline to ensure all models and their configurations are downloaded and cached. + # This avoids an issue where the non-ONNX components (tokenizer, scheduler, and feature extractor) are not + # automatically cached correctly if individual models are fetched one at a time. 
+ print("Download stable diffusion PyTorch pipeline...") + pipeline = DiffusionPipeline.from_pretrained(model_dir, torch_dtype=torch.float32, **{"local_files_only": True}) + # config.vae_sample_size = pipeline.vae.config.sample_size + # config.cross_attention_dim = pipeline.unet.config.cross_attention_dim + # config.unet_sample_size = pipeline.unet.config.sample_size + + model_info = {} + + submodel_names = [ "unet" ] + + if image_encoder: + submodel_names.append("image_encoder") + + for submodel_name in submodel_names: + print(f"\nOptimizing {submodel_name}") + + olive_config = None + with (script_dir / f"config_{submodel_name}.json").open() as fin: + olive_config = json.load(fin) + olive_config = update_config_with_provider(olive_config, provider) + olive_config["input_model"]["config"]["model_path"] = model_dir + + run_res = olive_run(olive_config) + + from sd_utils.ort import save_optimized_onnx_submodel + + save_optimized_onnx_submodel(submodel_name, provider, model_info) + + from sd_utils.ort import save_onnx_pipeline + + save_onnx_pipeline( + model_info, model_output, pipeline, submodel_names + ) + + return model_info + + +def parse_common_args(raw_args): + parser = argparse.ArgumentParser("Common arguments") + parser.add_argument("--model_input", default="stable-diffusion-v1-5", type=str) + parser.add_argument("--model_output", default="stable-diffusion-v1-5", type=Path) + parser.add_argument("--image_encoder",action="store_true", help="Create image encoder model") + parser.add_argument("--provider", default="dml", type=str, choices=["dml", "cuda"], help="Execution provider to use") + parser.add_argument("--optimize", action="store_true", help="Runs the optimization step") + parser.add_argument("--clean_cache", action="store_true", help="Deletes the Olive cache") + parser.add_argument("--test_unoptimized", action="store_true", help="Use unoptimized model for inference") + parser.add_argument("--tempdir", default=None, type=str, help="Root directory for tempfile directories and files") + return parser.parse_known_args(raw_args) + + +def parse_ort_args(raw_args): + parser = argparse.ArgumentParser("ONNX Runtime arguments") + + parser.add_argument( + "--static_dims", + action="store_true", + help="DEPRECATED (now enabled by default). 
Use --dynamic_dims to disable static_dims.", + ) + parser.add_argument("--dynamic_dims", action="store_true", help="Disable static shape optimization") + + return parser.parse_known_args(raw_args) + + +def main(raw_args=None): + common_args, extra_args = parse_common_args(raw_args) + + provider = common_args.provider + model_input = common_args.model_input + model_output = common_args.model_output + + script_dir = Path(__file__).resolve().parent + + + if common_args.clean_cache: + shutil.rmtree(script_dir / "cache", ignore_errors=True) + + ort_args = None, None + ort_args, extra_args = parse_ort_args(extra_args) + + if common_args.optimize or not model_output.exists(): + set_tempdir(common_args.tempdir) + + # TODO(jstoecker): clean up warning filter (mostly during conversion from torch to ONNX) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + from sd_utils.ort import validate_args + + validate_args(ort_args, common_args.provider) + optimize(common_args.model_input, common_args.model_output, common_args.provider, common_args.image_encoder) + + if not common_args.optimize: + print("TODO: Create OnnxStableCascadePipeline") + + +if __name__ == "__main__": + main() diff --git a/OnnxStack.Converter/stable_diffusion_video/models.py b/OnnxStack.Converter/stable_diffusion_video/models.py new file mode 100644 index 0000000..c5874b4 --- /dev/null +++ b/OnnxStack.Converter/stable_diffusion_video/models.py @@ -0,0 +1,49 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import config +import torch +from typing import Union, Optional, Tuple +from diffusers import UNetSpatioTemporalConditionModel +from transformers.models.clip.modeling_clip import CLIPVisionModelWithProjection +from dataclasses import dataclass + +# Helper latency-only dataloader that creates random tensors with no label +class RandomDataLoader: + def __init__(self, create_inputs_func, batchsize, torch_dtype): + self.create_input_func = create_inputs_func + self.batchsize = batchsize + self.torch_dtype = torch_dtype + + def __getitem__(self, idx): + label = None + return self.create_input_func(self.batchsize, self.torch_dtype), label + + + +# ----------------------------------------------------------------------------- +# UNET +# ----------------------------------------------------------------------------- + +def unet_inputs(batchsize, torch_dtype, is_conversion_inputs=False): + inputs = { + "sample": torch.rand((batchsize, 25, 8, 72, 128), dtype=torch_dtype), + "timestep": torch.rand((1,), dtype=torch_dtype), + "encoder_hidden_states": torch.rand((batchsize , 1, 1024), dtype=torch_dtype), + "added_time_ids": torch.rand((batchsize, 3), dtype=torch_dtype) + } + return inputs + + +def unet_load(model_name): + model = UNetSpatioTemporalConditionModel.from_pretrained(model_name, subfolder="unet") + return model + + +def unet_conversion_inputs(model=None): + return tuple(unet_inputs(1, torch.float32, True).values()) + + +def unet_data_loader(data_dir, batchsize, *args, **kwargs): + return RandomDataLoader(unet_inputs, batchsize, torch.float16) diff --git a/OnnxStack.Converter/stable_diffusion_video/requirements.txt b/OnnxStack.Converter/stable_diffusion_video/requirements.txt new file mode 100644 index 0000000..15b9198 --- /dev/null +++ b/OnnxStack.Converter/stable_diffusion_video/requirements.txt @@ -0,0 +1,9 @@ 
+accelerate +diffusers +onnx +pillow +protobuf==3.20.3 # protobuf 4.x aborts with OOM when optimizing unet +tabulate +torch +transformers +onnxruntime-directml>=1.16.0 diff --git a/OnnxStack.Converter/stable_diffusion_video/sd_utils/ort.py b/OnnxStack.Converter/stable_diffusion_video/sd_utils/ort.py new file mode 100644 index 0000000..72746f7 --- /dev/null +++ b/OnnxStack.Converter/stable_diffusion_video/sd_utils/ort.py @@ -0,0 +1,117 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import os +import json +import shutil +import sys +from pathlib import Path +from typing import Dict + +import onnxruntime as ort +from diffusers import OnnxRuntimeModel, StableCascadePriorPipeline +from onnxruntime import __version__ as OrtVersion +from packaging import version + +from olive.model import ONNXModelHandler + +# ruff: noqa: TID252, T201 + + +def update_cuda_config(config: Dict): + if version.parse(OrtVersion) < version.parse("1.17.0"): + # disable skip_group_norm fusion since there is a shape inference bug which leads to invalid models + config["passes"]["optimize_cuda"]["config"]["optimization_options"] = {"enable_skip_group_norm": False} + config["pass_flows"] = [["convert", "optimize_cuda"]] + config["systems"]["local_system"]["config"]["accelerators"][0]["execution_providers"] = ["CUDAExecutionProvider"] + return config + + +def validate_args(args, provider): + ort.set_default_logger_severity(4) + if args.static_dims: + print( + "WARNING: the --static_dims option is deprecated, and static shape optimization is enabled by default. " + "Use --dynamic_dims to disable static shape optimization." + ) + + validate_ort_version(provider) + + +def validate_ort_version(provider: str): + if provider == "dml" and version.parse(OrtVersion) < version.parse("1.16.0"): + print("This script requires onnxruntime-directml 1.16.0 or newer") + sys.exit(1) + elif provider == "cuda" and version.parse(OrtVersion) < version.parse("1.17.0"): + if version.parse(OrtVersion) < version.parse("1.16.2"): + print("This script requires onnxruntime-gpu 1.16.2 or newer") + sys.exit(1) + print( + f"WARNING: onnxruntime {OrtVersion} has known issues with shape inference for SkipGroupNorm. Will disable" + " skip_group_norm fusion. onnxruntime-gpu 1.17.0 or newer is strongly recommended!" 
+ ) + + +def save_optimized_onnx_submodel(submodel_name, provider, model_info): + footprints_file_path = ( + Path(__file__).resolve().parents[1] / "footprints" / f"{submodel_name}_gpu-{provider}_footprints.json" + ) + with footprints_file_path.open("r") as footprint_file: + footprints = json.load(footprint_file) + + conversion_footprint = None + optimizer_footprint = None + for footprint in footprints.values(): + if footprint["from_pass"] == "OnnxConversion": + conversion_footprint = footprint + elif footprint["from_pass"] == "OrtTransformersOptimization": + optimizer_footprint = footprint + + assert conversion_footprint + assert optimizer_footprint + + unoptimized_olive_model = ONNXModelHandler(**conversion_footprint["model_config"]["config"]) + optimized_olive_model = ONNXModelHandler(**optimizer_footprint["model_config"]["config"]) + + model_info[submodel_name] = { + "unoptimized": { + "path": Path(unoptimized_olive_model.model_path), + "data": Path(unoptimized_olive_model.model_path + ".data"), + }, + "optimized": { + "path": Path(optimized_olive_model.model_path), + "data": Path(optimized_olive_model.model_path + ".data"), + }, + } + + print(f"Unoptimized Model : {model_info[submodel_name]['unoptimized']['path']}") + print(f"Optimized Model : {model_info[submodel_name]['optimized']['path']}") + + +def save_onnx_pipeline( + model_info, model_output, pipeline, submodel_names +): + # Save the unoptimized models in a directory structure that the diffusers library can load and run. + # This is optional, and the optimized models can be used directly in a custom pipeline if desired. + # print("\nCreating ONNX pipeline...") + + # TODO: Create OnnxStableCascadePipeline + + # Create a copy of the unoptimized model directory, then overwrite with optimized models from the olive cache. 
+ print("Copying optimized models...") + for passType in ["optimized", "unoptimized"]: + model_dir = model_output / passType + for submodel_name in submodel_names: + src_path = model_info[submodel_name][passType]["path"] # model.onnx + src_data_path = model_info[submodel_name][passType]["data"]# model.onnx.data + + dst_path = model_dir / submodel_name + if not os.path.exists(dst_path): + os.makedirs(dst_path, exist_ok=True) + + shutil.copyfile(src_path, dst_path / "model.onnx") + if os.path.exists(src_data_path): + shutil.copyfile(src_data_path, dst_path / "model.onnx.data") + + print(f"The converted model is located here: {model_output}") From f7d53bcbe8bc6214e1dd8c5c63450958226ea39c Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Sat, 20 Apr 2024 15:13:57 +1200 Subject: [PATCH 2/3] optimize --- .../stable_diffusion_video/config_unet.json | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/OnnxStack.Converter/stable_diffusion_video/config_unet.json b/OnnxStack.Converter/stable_diffusion_video/config_unet.json index 8256a88..fd91a29 100644 --- a/OnnxStack.Converter/stable_diffusion_video/config_unet.json +++ b/OnnxStack.Converter/stable_diffusion_video/config_unet.json @@ -65,7 +65,32 @@ "opt_level": 0, "float16": true, "use_gpu": true, - "keep_io_types": true + "keep_io_types": true, + "optimization_options": { + "enable_gelu": true, + "enable_layer_norm": true, + "enable_attention": true, + "use_multi_head_attention": true, + "enable_skip_layer_norm": false, + "enable_embed_layer_norm": true, + "enable_bias_skip_layer_norm": false, + "enable_bias_gelu": true, + "enable_gelu_approximation": false, + "enable_qordered_matmul": false, + "enable_shape_inference": true, + "enable_gemm_fast_gelu": false, + "enable_nhwc_conv": false, + "enable_group_norm": true, + "enable_bias_splitgelu": false, + "enable_packed_qkv": true, + "enable_packed_kv": true, + "enable_bias_add": false, + "group_norm_channels_last": false + }, + "force_fp32_ops": ["RandomNormalLike"], + "force_fp16_inputs": { + "GroupNorm": [0, 1, 2] + } } } }, From 100739ef65eb039e86cee6e715d16beeb4953921 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Sat, 20 Apr 2024 17:31:35 +1200 Subject: [PATCH 3/3] Vae models conversion --- .../config_vae_decoder.json | 105 ++++++++++++++++++ .../config_vae_encoder.json | 103 +++++++++++++++++ .../stable_diffusion_video/convert.py | 2 +- .../stable_diffusion_video/models.py | 54 ++++++++- 4 files changed, 262 insertions(+), 2 deletions(-) create mode 100644 OnnxStack.Converter/stable_diffusion_video/config_vae_decoder.json create mode 100644 OnnxStack.Converter/stable_diffusion_video/config_vae_encoder.json diff --git a/OnnxStack.Converter/stable_diffusion_video/config_vae_decoder.json b/OnnxStack.Converter/stable_diffusion_video/config_vae_decoder.json new file mode 100644 index 0000000..88950b7 --- /dev/null +++ b/OnnxStack.Converter/stable_diffusion_video/config_vae_decoder.json @@ -0,0 +1,105 @@ +{ + "input_model": { + "type": "PyTorchModel", + "config": { + "model_path": "stabilityai/stable-video-diffusion-img2vid-xt", + "model_loader": "vae_decoder_load", + "model_script": "models.py", + "io_config": { + "input_names": [ "latent_sample", "num_frames" ], + "output_names": [ "sample" ], + "dynamic_axes": { + "latent_sample": { "0": "batch", "1": "channels", "2": "height", "3": "width" } + } + }, + "dummy_inputs_func": "vae_decoder_conversion_inputs" + } + }, + "systems": { + "local_system": { + "type": "LocalSystem", + "config": { + "accelerators": [ + { + 
"device": "gpu", + "execution_providers": [ + "DmlExecutionProvider" + ] + } + ] + } + } + }, + "evaluators": { + "common_evaluator": { + "metrics": [ + { + "name": "latency", + "type": "latency", + "sub_types": [{"name": "avg"}], + "user_config": { + "user_script": "models.py", + "dataloader_func": "vae_decoder_data_loader", + "batch_size": 1 + } + } + ] + } + }, + "passes": { + "convert": { + "type": "OnnxConversion", + "config": { + "target_opset": 16 + } + }, + "optimize": { + "type": "OrtTransformersOptimization", + "config": { + "model_type": "vae", + "opt_level": 0, + "float16": true, + "use_gpu": true, + "keep_io_types": false, + "optimization_options": { + "enable_gelu": true, + "enable_layer_norm": true, + "enable_attention": true, + "use_multi_head_attention": true, + "enable_skip_layer_norm": false, + "enable_embed_layer_norm": true, + "enable_bias_skip_layer_norm": false, + "enable_bias_gelu": true, + "enable_gelu_approximation": false, + "enable_qordered_matmul": false, + "enable_shape_inference": true, + "enable_gemm_fast_gelu": false, + "enable_nhwc_conv": false, + "enable_group_norm": true, + "enable_bias_splitgelu": false, + "enable_packed_qkv": true, + "enable_packed_kv": true, + "enable_bias_add": false, + "group_norm_channels_last": false + }, + "force_fp32_ops": ["RandomNormalLike"], + "force_fp16_inputs": { + "GroupNorm": [0, 1, 2] + } + } + } + }, + "pass_flows": [ + ["convert", "optimize"] + ], + "engine": { + "log_severity_level": 0, + "evaluator": "common_evaluator", + "evaluate_input_model": false, + "host": "local_system", + "target": "local_system", + "cache_dir": "cache", + "output_name": "vae_decoder", + "output_dir": "footprints" + } +} diff --git a/OnnxStack.Converter/stable_diffusion_video/config_vae_encoder.json b/OnnxStack.Converter/stable_diffusion_video/config_vae_encoder.json new file mode 100644 index 0000000..da6c806 --- /dev/null +++ b/OnnxStack.Converter/stable_diffusion_video/config_vae_encoder.json @@ -0,0 +1,103 @@ +{ + "input_model": { + "type": "PyTorchModel", + "config": { + "model_path": "stabilityai/stable-video-diffusion-img2vid-xt", + "model_loader": "vae_encoder_load", + "model_script": "models.py", + "io_config": { + "input_names": [ "sample" ], + "output_names": [ "latent_sample" ], + "dynamic_axes": { "sample": { "0": "batch", "1": "channels", "2": "height", "3": "width" } } + }, + "dummy_inputs_func": "vae_encoder_conversion_inputs" + } + }, + "systems": { + "local_system": { + "type": "LocalSystem", + "config": { + "accelerators": [ + { + "device": "gpu", + "execution_providers": [ + "DmlExecutionProvider" + ] + } + ] + } + } + }, + "evaluators": { + "common_evaluator": { + "metrics": [ + { + "name": "latency", + "type": "latency", + "sub_types": [{"name": "avg"}], + "user_config": { + "user_script": "models.py", + "dataloader_func": "vae_encoder_data_loader", + "batch_size": 1 + } + } + ] + } + }, + "passes": { + "convert": { + "type": "OnnxConversion", + "config": { + "target_opset": 16 + } + }, + "optimize": { + "type": "OrtTransformersOptimization", + "config": { + "model_type": "vae", + "opt_level": 0, + "float16": true, + "use_gpu": true, + "keep_io_types": false, + "optimization_options": { + "enable_gelu": true, + "enable_layer_norm": true, + "enable_attention": true, + "use_multi_head_attention": true, + "enable_skip_layer_norm": false, + "enable_embed_layer_norm": true, + "enable_bias_skip_layer_norm": false, + "enable_bias_gelu": true, + "enable_gelu_approximation": false, + "enable_qordered_matmul": false, + 
"enable_shape_inference": true, + "enable_gemm_fast_gelu": false, + "enable_nhwc_conv": false, + "enable_group_norm": true, + "enable_bias_splitgelu": false, + "enable_packed_qkv": true, + "enable_packed_kv": true, + "enable_bias_add": false, + "group_norm_channels_last": false + }, + "force_fp32_ops": ["RandomNormalLike"], + "force_fp16_inputs": { + "GroupNorm": [0, 1, 2] + } + } + } + }, + "pass_flows": [ + ["convert", "optimize"] + ], + "engine": { + "log_severity_level": 0, + "evaluator": "common_evaluator", + "evaluate_input_model": false, + "host": "local_system", + "target": "local_system", + "cache_dir": "cache", + "output_name": "vae_encoder", + "output_dir": "footprints" + } +} diff --git a/OnnxStack.Converter/stable_diffusion_video/convert.py b/OnnxStack.Converter/stable_diffusion_video/convert.py index fe7af38..537530f 100644 --- a/OnnxStack.Converter/stable_diffusion_video/convert.py +++ b/OnnxStack.Converter/stable_diffusion_video/convert.py @@ -120,7 +120,7 @@ def optimize( model_info = {} - submodel_names = [ "unet" ] + submodel_names = [ "vae_encoder", "vae_decoder" ] if image_encoder: submodel_names.append("image_encoder") diff --git a/OnnxStack.Converter/stable_diffusion_video/models.py b/OnnxStack.Converter/stable_diffusion_video/models.py index c5874b4..3462d56 100644 --- a/OnnxStack.Converter/stable_diffusion_video/models.py +++ b/OnnxStack.Converter/stable_diffusion_video/models.py @@ -5,7 +5,7 @@ import config import torch from typing import Union, Optional, Tuple -from diffusers import UNetSpatioTemporalConditionModel +from diffusers import UNetSpatioTemporalConditionModel, AutoencoderKLTemporalDecoder from transformers.models.clip.modeling_clip import CLIPVisionModelWithProjection from dataclasses import dataclass @@ -47,3 +47,55 @@ def unet_conversion_inputs(model=None): def unet_data_loader(data_dir, batchsize, *args, **kwargs): return RandomDataLoader(unet_inputs, batchsize, torch.float16) + + + +# ----------------------------------------------------------------------------- +# VAE ENCODER +# ----------------------------------------------------------------------------- + + +def vae_encoder_inputs(batchsize, torch_dtype): + return {"sample": torch.rand((batchsize, 3, 72, 128), dtype=torch_dtype)} + + +def vae_encoder_load(model_name): + model = AutoencoderKLTemporalDecoder.from_pretrained(model_name, subfolder="vae", use_safetensors=True) + model.forward = lambda sample: model.encode(sample)[0].sample() + return model + + +def vae_encoder_conversion_inputs(model=None): + return tuple(vae_encoder_inputs(1, torch.float32).values()) + + +def vae_encoder_data_loader(data_dir, batchsize, *args, **kwargs): + return RandomDataLoader(vae_encoder_inputs, batchsize, torch.float16) + + + + +# ----------------------------------------------------------------------------- +# VAE DECODER +# ----------------------------------------------------------------------------- + + +def vae_decoder_inputs(batchsize, torch_dtype): + return { + "latent_sample": torch.rand((batchsize, 4, 72, 128), dtype=torch_dtype), + "num_frames": 1, + } + + +def vae_decoder_load(model_name): + model = AutoencoderKLTemporalDecoder.from_pretrained(model_name, subfolder="vae", use_safetensors=True) + model.forward = model.decode + return model + + +def vae_decoder_conversion_inputs(model=None): + return tuple(vae_decoder_inputs(1, torch.float32).values()) + + +def vae_decoder_data_loader(data_dir, batchsize, *args, **kwargs): + return RandomDataLoader(vae_decoder_inputs, batchsize, torch.float16) \ No 
newline at end of file
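
Below is a minimal smoke-test sketch (not part of the patches) for loading a converted UNet with onnxruntime-directml and running one step on random inputs. It assumes the unet submodel was converted (the third patch switches the default `submodel_names` to the VAE models) and that the output landed in the `optimized\unet` folder laid out by `save_onnx_pipeline`; the input names and shapes are taken from `config_unet.json` and `models.py`.

```python
# Hypothetical smoke test for the converted UNet; the model path and tensor
# shapes are assumptions based on the README example and models.py.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    r"..\converted\optimized\unet\model.onnx",  # assumed layout: <model_output>\optimized\<submodel>\model.onnx
    providers=["DmlExecutionProvider"],
)

# Shapes mirror unet_inputs() in models.py; keep_io_types=true keeps IO in float32.
feeds = {
    "sample": np.random.rand(1, 25, 8, 72, 128).astype(np.float32),
    "timestep": np.random.rand(1).astype(np.float32),
    "encoder_hidden_states": np.random.rand(1, 1, 1024).astype(np.float32),
    "added_time_ids": np.random.rand(1, 3).astype(np.float32),
}

(out_sample,) = session.run(["out_sample"], feeds)
print("out_sample:", out_sample.shape)
```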