From f511c666f57016e4d7e23ee3d19cf56b25a75a2b Mon Sep 17 00:00:00 2001 From: anzr299 Date: Fri, 10 Jan 2025 11:21:37 +0400 Subject: [PATCH 1/7] Fix wrapper class for sd3 --- .../sd3_torch_fx_helper.py | 56 ++++--------------- 1 file changed, 11 insertions(+), 45 deletions(-) diff --git a/notebooks/stable-diffusion-v3/sd3_torch_fx_helper.py b/notebooks/stable-diffusion-v3/sd3_torch_fx_helper.py index b4253e5cfd9..138cb63b486 100644 --- a/notebooks/stable-diffusion-v3/sd3_torch_fx_helper.py +++ b/notebooks/stable-diffusion-v3/sd3_torch_fx_helper.py @@ -17,71 +17,37 @@ def get_sd3_pipeline(model_id="stabilityai/stable-diffusion-3-medium-diffusers") # This function takes in the models of a SD3 pipeline in the torch fx representation and returns an SD3 pipeline with wrapped models. def init_pipeline(models_dict, configs_dict, model_id="stabilityai/stable-diffusion-3-medium-diffusers"): wrapped_models = {} - def wrap_model(pipe_model, base_class, config): - base_class = (base_class,) if not isinstance(base_class, tuple) else base_class - - class WrappedModel(*base_class): + class ModelWrapper(base_class): def __init__(self, model, config): - cls_name = base_class[0].__name__ + cls_name = base_class.__name__ if isinstance(config, dict): super().__init__(**config) else: super().__init__(config) + + modules_to_delete = [name for name in self._modules.keys()] + for name in modules_to_delete: + del self._modules[name] + if cls_name == "AutoencoderKL": self.encoder = model.encoder self.decoder = model.decoder else: self.model = model - def forward(self, *args, **kwargs): + kwargs.pop('joint_attention_kwargs', None) + kwargs.pop('return_dict', None) return self.model(*args, **kwargs) - class WrappedTransformer(*base_class): - @register_to_config - def __init__( - self, - model, - sample_size, - patch_size, - in_channels, - num_layers, - attention_head_dim, - num_attention_heads, - joint_attention_dim, - caption_projection_dim, - pooled_projection_dim, - out_channels, - pos_embed_max_size, - dual_attention_layers, - qk_norm, - ): - super().__init__() - self.model = model + return ModelWrapper(pipe_model, config) - def forward(self, *args, **kwargs): - del kwargs["joint_attention_kwargs"] - del kwargs["return_dict"] - return self.model(*args, **kwargs) - - if len(base_class) > 1: - return WrappedTransformer(pipe_model, **config) - return WrappedModel(pipe_model, config) - - wrapped_models["transformer"] = wrap_model( - models_dict["transformer"], - ( - ModelMixin, - ConfigMixin, - ), - configs_dict["transformer"], - ) + wrapped_models["transformer"] = wrap_model(models_dict["transformer"], SD3Transformer2DModel, configs_dict["transformer"]) wrapped_models["vae"] = wrap_model(models_dict["vae"], AutoencoderKL, configs_dict["vae"]) wrapped_models["text_encoder"] = wrap_model(models_dict["text_encoder"], CLIPTextModelWithProjection, configs_dict["text_encoder"]) wrapped_models["text_encoder_2"] = wrap_model(models_dict["text_encoder_2"], CLIPTextModelWithProjection, configs_dict["text_encoder_2"]) pipe = StableDiffusion3Pipeline.from_pretrained(model_id, text_encoder_3=None, tokenizer_3=None, **wrapped_models) - return pipe From 14b41d4e26a221c704babff4812e2446c411f901 Mon Sep 17 00:00:00 2001 From: Aamir Nazir Date: Tue, 14 Jan 2025 11:29:47 +0400 Subject: [PATCH 2/7] black fix --- notebooks/stable-diffusion-v3/sd3_torch_fx_helper.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/notebooks/stable-diffusion-v3/sd3_torch_fx_helper.py 
b/notebooks/stable-diffusion-v3/sd3_torch_fx_helper.py index 138cb63b486..51705a55854 100644 --- a/notebooks/stable-diffusion-v3/sd3_torch_fx_helper.py +++ b/notebooks/stable-diffusion-v3/sd3_torch_fx_helper.py @@ -17,6 +17,7 @@ def get_sd3_pipeline(model_id="stabilityai/stable-diffusion-3-medium-diffusers") # This function takes in the models of a SD3 pipeline in the torch fx representation and returns an SD3 pipeline with wrapped models. def init_pipeline(models_dict, configs_dict, model_id="stabilityai/stable-diffusion-3-medium-diffusers"): wrapped_models = {} + def wrap_model(pipe_model, base_class, config): class ModelWrapper(base_class): def __init__(self, model, config): @@ -35,9 +36,10 @@ def __init__(self, model, config): self.decoder = model.decoder else: self.model = model + def forward(self, *args, **kwargs): - kwargs.pop('joint_attention_kwargs', None) - kwargs.pop('return_dict', None) + kwargs.pop("joint_attention_kwargs", None) + kwargs.pop("return_dict", None) return self.model(*args, **kwargs) return ModelWrapper(pipe_model, config) From e88101419c823b80d887898190b2f440c6455ad1 Mon Sep 17 00:00:00 2001 From: anzr299 Date: Wed, 15 Jan 2025 11:36:51 +0400 Subject: [PATCH 3/7] black fix --- notebooks/mobileclip-video-search/mobileclip-video-search.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notebooks/mobileclip-video-search/mobileclip-video-search.ipynb b/notebooks/mobileclip-video-search/mobileclip-video-search.ipynb index 6bac8f879ca..1a9758756c6 100644 --- a/notebooks/mobileclip-video-search/mobileclip-video-search.ipynb +++ b/notebooks/mobileclip-video-search/mobileclip-video-search.ipynb @@ -112,7 +112,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install -q \"openvino>=2024.5.0\" \n", + "%pip install -q \"openvino>=2024.5.0\"\n", "%pip install -q \"git+https://github.com/huggingface/optimum-intel.git\" \"transformers>=4.45\" \"tokenizers>=0.20\" --extra-index-url https://download.pytorch.org/whl/cpu" ] }, From 3904c133bc2913931ad239b144577705af1c42d7 Mon Sep 17 00:00:00 2001 From: anzr299 Date: Wed, 15 Jan 2025 14:53:07 +0400 Subject: [PATCH 4/7] init --- .../stable-diffusion-v3-torch-fx.ipynb | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb b/notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb index 4fed1daafed..eba0a11081a 100644 --- a/notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb +++ b/notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb @@ -93,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -129,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -179,7 +179,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -296,6 +296,17 @@ "unet_kwargs[\"encoder_hidden_states\"] = torch.ones((2, 154, 4096))\n", "unet_kwargs[\"pooled_projections\"] = torch.ones((2, 2048))\n", "\n", + "#Feature map height and width are dynamic\n", + "fm_height = torch.export.Dim('fm_height', min=16, max=256)\n", + "fm_width = torch.export.Dim('fm_width', min=16, max=256)\n", + "dim = torch.export.Dim('dim', min=1, max=16)\n", + "fm_height = 16*dim\n", + "fm_width = 16*dim\n", + "\n", + "dynamic_shapes = {\"sample\": {2: fm_height, 3: fm_width}}\n", + "#iterate through the unet kwargs 
and set only hidden state kwarg to dynamic\n", + "dynamic_shapes_transformer = {key: (None if key != \"hidden_states\" else {2: fm_height, 3: fm_width}) for key in unet_kwargs.keys()}\n", + "\n", "with torch.no_grad():\n", " with disable_patching():\n", " text_encoder = torch.export.export_for_training(\n", @@ -308,10 +319,10 @@ " args=(text_encoder_input,),\n", " kwargs=(text_encoder_kwargs),\n", " ).module()\n", - " pipe.vae.decoder = torch.export.export_for_training(pipe.vae.decoder.eval(), args=(vae_decoder_input,)).module()\n", - " pipe.vae.encoder = torch.export.export_for_training(pipe.vae.encoder.eval(), args=(vae_encoder_input,)).module()\n", + " pipe.vae.decoder = torch.export.export_for_training(pipe.vae.decoder.eval(), args=(vae_decoder_input,), dynamic_shapes=dynamic_shapes).module()\n", + " pipe.vae.encoder = torch.export.export_for_training(pipe.vae.encoder.eval(), args=(vae_encoder_input,), dynamic_shapes=dynamic_shapes).module()\n", " vae = pipe.vae\n", - " transformer = torch.export.export_for_training(pipe.transformer.eval(), args=(), kwargs=(unet_kwargs)).module()\n", + " transformer = torch.export.export_for_training(pipe.transformer.eval(), args=(), kwargs=(unet_kwargs), dynamic_shapes=dynamic_shapes_transformer).module()\n", "models_dict = {}\n", "models_dict[\"transformer\"] = transformer\n", "models_dict[\"vae\"] = vae\n", @@ -765,7 +776,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": ".venv", "language": "python", "name": "python3" }, From 93955ef803de3a559f6916d4ba0f6a192dd23609 Mon Sep 17 00:00:00 2001 From: anzr299 Date: Wed, 15 Jan 2025 14:53:44 +0400 Subject: [PATCH 5/7] black reformat --- .../stable-diffusion-v3-torch-fx.ipynb | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb b/notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb index eba0a11081a..41f43b6e715 100644 --- a/notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb +++ b/notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb @@ -296,15 +296,15 @@ "unet_kwargs[\"encoder_hidden_states\"] = torch.ones((2, 154, 4096))\n", "unet_kwargs[\"pooled_projections\"] = torch.ones((2, 2048))\n", "\n", - "#Feature map height and width are dynamic\n", - "fm_height = torch.export.Dim('fm_height', min=16, max=256)\n", - "fm_width = torch.export.Dim('fm_width', min=16, max=256)\n", - "dim = torch.export.Dim('dim', min=1, max=16)\n", - "fm_height = 16*dim\n", - "fm_width = 16*dim\n", + "# Feature map height and width are dynamic\n", + "fm_height = torch.export.Dim(\"fm_height\", min=16, max=256)\n", + "fm_width = torch.export.Dim(\"fm_width\", min=16, max=256)\n", + "dim = torch.export.Dim(\"dim\", min=1, max=16)\n", + "fm_height = 16 * dim\n", + "fm_width = 16 * dim\n", "\n", "dynamic_shapes = {\"sample\": {2: fm_height, 3: fm_width}}\n", - "#iterate through the unet kwargs and set only hidden state kwarg to dynamic\n", + "# iterate through the unet kwargs and set only hidden state kwarg to dynamic\n", "dynamic_shapes_transformer = {key: (None if key != \"hidden_states\" else {2: fm_height, 3: fm_width}) for key in unet_kwargs.keys()}\n", "\n", "with torch.no_grad():\n", @@ -322,7 +322,9 @@ " pipe.vae.decoder = torch.export.export_for_training(pipe.vae.decoder.eval(), args=(vae_decoder_input,), dynamic_shapes=dynamic_shapes).module()\n", " pipe.vae.encoder = torch.export.export_for_training(pipe.vae.encoder.eval(), 
args=(vae_encoder_input,), dynamic_shapes=dynamic_shapes).module()\n", " vae = pipe.vae\n", - " transformer = torch.export.export_for_training(pipe.transformer.eval(), args=(), kwargs=(unet_kwargs), dynamic_shapes=dynamic_shapes_transformer).module()\n", + " transformer = torch.export.export_for_training(\n", + " pipe.transformer.eval(), args=(), kwargs=(unet_kwargs), dynamic_shapes=dynamic_shapes_transformer\n", + " ).module()\n", "models_dict = {}\n", "models_dict[\"transformer\"] = transformer\n", "models_dict[\"vae\"] = vae\n", From 1813c2f517463bd77d19464db0fa38c39070f7b4 Mon Sep 17 00:00:00 2001 From: anzr299 Date: Thu, 23 Jan 2025 11:21:54 +0400 Subject: [PATCH 6/7] fix issue --- .../stable-diffusion-v3-torch-fx.ipynb | 906 +++++++++++++++++- 1 file changed, 886 insertions(+), 20 deletions(-) diff --git a/notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb b/notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb index 5ba3a57e2f6..2a0ca9ea9d7 100644 --- a/notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb +++ b/notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb @@ -93,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -129,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -146,7 +146,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "editable": true, "slideshow": { @@ -157,7 +157,87 @@ "get_sd3_pipeline()": "get_sd3_pipeline(\"katuni4ka/tiny-random-sd3\")" } }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/user/Downloads/ov_notebooks_sd3/openvino_notebooks/.venv/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: `Transformer2DModelOutput` is deprecated and will be removed in version 1.0.0. Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. 
Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead.\n", + " deprecate(\"Transformer2DModelOutput\", \"1.0.0\", deprecation_message)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "987b12d03d8246e8a94a3ba253a321aa", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading pipeline components...: 0%| | 0/7 [00:00" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import numpy as np\n", "import torch\n", @@ -232,9 +1017,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e5506d8813f14ff1b698b3a4aefe5544", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from notebook_utils import device_widget\n", "\n", @@ -263,7 +1064,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { "editable": true, "slideshow": { @@ -278,7 +1079,15 @@ "torch.ones((2, 2048))": "torch.ones((2, 64))" } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, onnx, openvino\n" + ] + } + ], "source": [ "import torch\n", "from nncf.torch.dynamic_graph.patch_pytorch import disable_patching\n", @@ -357,9 +1166,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0006d8aeafad4b13b9d97664b5b84132", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Checkbox(value=True, description='Quantization')" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from notebook_utils import quantization_widget\n", "\n", @@ -417,7 +1242,50 @@ "pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512)": "pipe(prompt, num_inference_steps=num_inference_steps)" } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1536 24 1536 23\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f5368c09da2a4ec48b6a95a07568df7e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading pipeline components...: 0%| | 0/7 [00:00 Date: Thu, 23 Jan 2025 11:25:05 +0400 Subject: [PATCH 7/7] clean output --- .../stable-diffusion-v3-torch-fx.ipynb | 890 +----------------- 1 file changed, 11 insertions(+), 879 deletions(-) diff --git a/notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb b/notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb index 2a0ca9ea9d7..c665bff3581 100644 --- a/notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb +++ b/notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb @@ -146,7 +146,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "editable": true, "slideshow": { @@ -157,87 +157,7 @@ "get_sd3_pipeline()": "get_sd3_pipeline(\"katuni4ka/tiny-random-sd3\")" } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", 
- "text": [ - "/home/user/Downloads/ov_notebooks_sd3/openvino_notebooks/.venv/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: `Transformer2DModelOutput` is deprecated and will be removed in version 1.0.0. Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead.\n", - " deprecate(\"Transformer2DModelOutput\", \"1.0.0\", deprecation_message)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "987b12d03d8246e8a94a3ba253a321aa", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Loading pipeline components...: 0%| | 0/7 [00:00" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "import numpy as np\n", "import torch\n", @@ -1017,25 +232,9 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e5506d8813f14ff1b698b3a4aefe5544", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "from notebook_utils import device_widget\n", "\n", @@ -1064,7 +263,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "editable": true, "slideshow": { @@ -1079,15 +278,7 @@ "torch.ones((2, 2048))": "torch.ones((2, 64))" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, onnx, openvino\n" - ] - } - ], + "outputs": [], "source": [ "import torch\n", "from nncf.torch.dynamic_graph.patch_pytorch import disable_patching\n", @@ -1166,25 +357,9 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0006d8aeafad4b13b9d97664b5b84132", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Checkbox(value=True, description='Quantization')" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "from notebook_utils import quantization_widget\n", "\n", @@ -1242,50 +417,7 @@ "pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512)": "pipe(prompt, num_inference_steps=num_inference_steps)" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1536 24 1536 23\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f5368c09da2a4ec48b6a95a07568df7e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Loading pipeline components...: 0%| | 0/7 [00:00