diff --git a/notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb b/notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb
index 0298e07882f..c665bff3581 100644
--- a/notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb
+++ b/notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb
@@ -296,6 +296,17 @@
     "unet_kwargs[\"encoder_hidden_states\"] = torch.ones((2, 154, 4096))\n",
     "unet_kwargs[\"pooled_projections\"] = torch.ones((2, 2048))\n",
     "\n",
+    "# Feature map height and width are dynamic\n",
+    "fm_height = torch.export.Dim(\"fm_height\", min=16, max=256)\n",
+    "fm_width = torch.export.Dim(\"fm_width\", min=16, max=256)\n",
+    "dim = torch.export.Dim(\"dim\", min=1, max=16)\n",
+    "fm_height = 16 * dim\n",
+    "fm_width = 16 * dim\n",
+    "\n",
+    "dynamic_shapes = {\"sample\": {2: fm_height, 3: fm_width}}\n",
+    "# iterate through the unet kwargs and set only hidden state kwarg to dynamic\n",
+    "dynamic_shapes_transformer = {key: (None if key != \"hidden_states\" else {2: fm_height, 3: fm_width}) for key in unet_kwargs.keys()}\n",
+    "\n",
     "with torch.no_grad():\n",
     "    with disable_patching():\n",
     "        text_encoder = torch.export.export_for_training(\n",
@@ -308,10 +319,12 @@
     "            args=(text_encoder_input,),\n",
     "            kwargs=(text_encoder_kwargs),\n",
     "        ).module()\n",
-    "        pipe.vae.decoder = torch.export.export_for_training(pipe.vae.decoder.eval(), args=(vae_decoder_input,)).module()\n",
-    "        pipe.vae.encoder = torch.export.export_for_training(pipe.vae.encoder.eval(), args=(vae_encoder_input,)).module()\n",
+    "        pipe.vae.decoder = torch.export.export_for_training(pipe.vae.decoder.eval(), args=(vae_decoder_input,), dynamic_shapes=dynamic_shapes).module()\n",
+    "        pipe.vae.encoder = torch.export.export_for_training(pipe.vae.encoder.eval(), args=(vae_encoder_input,), dynamic_shapes=dynamic_shapes).module()\n",
     "        vae = pipe.vae\n",
-    "        transformer = torch.export.export_for_training(pipe.transformer.eval(), args=(), kwargs=(unet_kwargs)).module()\n",
+    "        transformer = torch.export.export_for_training(\n",
+    "            pipe.transformer.eval(), args=(), kwargs=(unet_kwargs), dynamic_shapes=dynamic_shapes_transformer\n",
+    "        ).module()\n",
     "models_dict = {}\n",
     "models_dict[\"transformer\"] = transformer\n",
     "models_dict[\"vae\"] = vae\n",
@@ -450,8 +463,6 @@
     "    ).shuffle(seed=42)\n",
     "\n",
     "    transformer_config = dict(pipe.transformer.config)\n",
-    "    if \"model\" in transformer_config:\n",
-    "        del transformer_config[\"model\"]\n",
     "    wrapped_unet = UNetWrapper(pipe.transformer.model, transformer_config)\n",
     "    pipe.transformer = wrapped_unet\n",
     "    # Run inference for data collection\n",
@@ -517,10 +528,10 @@
     "if to_quantize:\n",
     "    with disable_patching():\n",
     "        with torch.no_grad():\n",
-    "            nncf.compress_weights(text_encoder)\n",
-    "            nncf.compress_weights(text_encoder_2)\n",
-    "            nncf.compress_weights(vae_encoder)\n",
-    "            nncf.compress_weights(vae_decoder)\n",
+    "            text_encoder = nncf.compress_weights(text_encoder)\n",
+    "            text_encoder_2 = nncf.compress_weights(text_encoder_2)\n",
+    "            vae_encoder = nncf.compress_weights(vae_encoder)\n",
+    "            vae_decoder = nncf.compress_weights(vae_decoder)\n",
     "            quantized_transformer = nncf.quantize(\n",
     "                model=original_transformer,\n",
     "                calibration_dataset=nncf.Dataset(unet_calibration_data),\n",
@@ -766,7 +777,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
+   "display_name": ".venv",
    "language": "python",
    "name": "python3"
   },