diff --git a/.github/workflows/ci-checks.yml b/.github/workflows/ci-checks.yml
index 491d275173..ae28b10c53 100644
--- a/.github/workflows/ci-checks.yml
+++ b/.github/workflows/ci-checks.yml
@@ -16,14 +16,14 @@ jobs:
# actions-ref: main
check-schema:
- uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.10.1
+ uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.0
with:
azure-dir: ".azure"
check-package:
- uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.10.1
+ uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.11.0
with:
- actions-ref: v0.10.1
+ actions-ref: v0.11.0
import-name: "thunder"
artifact-name: dist-packages-${{ github.sha }}
testing-matrix: |
diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml
index 34751f2795..cbd8fb2aa6 100644
--- a/.github/workflows/ci-testing.yml
+++ b/.github/workflows/ci-testing.yml
@@ -79,7 +79,8 @@ jobs:
- name: Install package & dependencies
run: |
pip --version
- pip install -e '.[test]' -U \
+ pip install -e . -U \
+ -r requirements/test.txt \
--find-links=${TORCH_URL} ${PIP_EXTRA_FLAG}
pip list
shell: bash
diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml
index a9c8d99b6f..d37e381b40 100644
--- a/.github/workflows/docs-build.yml
+++ b/.github/workflows/docs-build.yml
@@ -15,7 +15,7 @@ defaults:
jobs:
build-docs:
- uses: Lightning-AI/utilities/.github/workflows/check-docs.yml@v0.10.1
+ uses: Lightning-AI/utilities/.github/workflows/check-docs.yml@v0.11.0
with:
python-version: "3.10"
requirements-file: "requirements/docs.txt"
diff --git a/.github/workflows/release-pypi.yml b/.github/workflows/release-pypi.yml
index e97fedf8e0..078f9e6066 100644
--- a/.github/workflows/release-pypi.yml
+++ b/.github/workflows/release-pypi.yml
@@ -27,7 +27,7 @@ jobs:
# We do this, since failures on test.pypi aren't that bad
- name: Publish to Test PyPI
if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
- uses: pypa/gh-action-pypi-publish@v1.8.12
+ uses: pypa/gh-action-pypi-publish@v1.8.14
with:
user: __token__
password: ${{ secrets.test_pypi_password }}
@@ -35,7 +35,7 @@ jobs:
- name: Publish distribution 📦 to PyPI
if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
- uses: pypa/gh-action-pypi-publish@v1.8.12
+ uses: pypa/gh-action-pypi-publish@v1.8.14
with:
user: __token__
password: ${{ secrets.pypi_password }}
diff --git a/README.md b/README.md
index b119d4860b..7d60063864 100644
--- a/README.md
+++ b/README.md
@@ -1,31 +1,94 @@
+
+
+
+
+
+**Make PyTorch models Lightning fast.**
+
+______________________________________________________________________
+
+
+ Lightning.ai •
+ Performance •
+ Get started •
+ Install •
+ Examples •
+ Features •
+ Documentation •
+
+
+[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/lightning-thunder/blob/main/LICENSE)
+[![CI testing](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml)
+[![General checks](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml)
+[![Documentation Status](https://readthedocs.org/projects/lightning-thunder/badge/?version=latest)](https://lightning-thunder.readthedocs.io/en/latest/?badge=latest)
+[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Lightning-AI/lightning-thunder/main.svg)](https://results.pre-commit.ci/latest/github/Lightning-AI/lightning-thunder/main)
+
+
+
# Welcome to ⚡ Lightning Thunder
-Lightning Thunder is a deep learning compiler for PyTorch. It makes PyTorch programs faster both on single accelerators or in distributed settings.
+**Thunder makes PyTorch models Lightning fast.**
+
+Thunder is a source-to-source compiler for PyTorch. It makes PyTorch programs faster by combining and using different hardware executors at once (ie: nvFuser, torch.compile, cuDNN, and TransformerEngine FP8).
+
+Works on single accelerators and in multi-GPU settings.
+Thunder aims to be usable, understandable, and extensible.
+
+## Performance
+
+Thunder can achieve significant speedups over standard PyTorch eager code, through the compounding effects of optimizations and the use of best-in-class executors. Here is an example of the pretraining throughput for Llama 2 7B as implemented in [LitGPT](https://github.com/Lightning-AI/litgpt).
+
+
+
+
+
+Thunder achieves a 40% speedup in training throughput compared to eager code on H100 using a combination of executors including nvFuser, torch.compile, cuDNN, and TransformerEngine FP8.
+
+Thunder supports distributed strategies like DDP and FSDP (ZeRO2 and ZeRO3). Here is the normalized throughput measured for Llama 2 7B (this time without FP8 mixed precision, support for FSDP is underway).
-The main goal for Lightning Thunder is to allow optimizing user programs in the most extensible and expressive way possible.
+
+
+
-**NOTE: Lightning Thunder is alpha and not ready for production runs.** Feel free to get involved, expect a few bumps along the way.
+**NOTE: Lightning Thunder is alpha.** Feel free to get involved, expect a few bumps along the way.
+
+## Get started
+
+Try Thunder without installing by using our [Zero to Thunder Tutorial Studio](https://lightning.ai/lightning-ai/studios/zero-to-thunder-tutorial).
## Install Thunder
-Install the nvFuser nightly, which will also install the matching PyTorch nightly:
+Install [nvFuser](https://github.com/NVIDIA/Fuser) nightly, and Thunder together
```bash
+# install nvFuser which installs the matching nightly PyTorch
pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.com
+
+# install thunder
+pip install lightning-thunder
```
-Install Thunder:
+
+ Advanced install options
+
+
+### Install from main
```bash
pip install git+https://github.com/Lightning-AI/lightning-thunder.git
```
-or install from the local repo:
+### Install to tinker and contribute
+
+Install this way to tinker with the internals and contribute:
```bash
-pip install .
+pip install -e .
```
+
+
+
## Hello World
Here is a simple example of how Thunder lets you compile and run PyTorch code:
@@ -56,11 +119,11 @@ print(result)
The compiled function `jfoo` takes and returns PyTorch tensors, just like the original function, so modules and functions compiled by Thunder can be used as part of larger PyTorch programs.
-## Running training
+## Train models
-Thunder is in its early stages, it should not be used for production runs yet.
+Thunder is in its early stages and should not be used for production runs yet.
-However, it can already deliver outstanding performance on models supported by [LitGPT](https://github.com/Lightning-AI/lit-gpt), such as Mistral, Llama2, Gemma, Falcon, and derivatives.
+However, it can already deliver outstanding performance on LLM model supported by [LitGPT](https://github.com/Lightning-AI/lit-gpt), such as Mistral, Llama 2, Gemma, Falcon, and others.
Run training loop for Llama, single-GPU:
@@ -76,25 +139,25 @@ python examples/lit-gpt/train_fsdp.py
See [README.md](examples/lit-gpt/README.md) for details on running LitGPT with Thunder.
-## What's in the box
+## Features
-Given a program, Thunder can generate an optimized program that:
+Given a Python callable or PyTorch module, Thunder can generate an optimized program that:
-- computes its forward and backward passes
-- coalesces operations into efficient fusion regions
-- dispatches computations to optimized kernels
-- distributes computations optimally across machines
+- Computes its forward and backward passes
+- Coalesces operations into efficient fusion regions
+- Dispatches computations to optimized kernels
+- Distributes computations optimally across machines
To do so, Thunder ships with:
-- a JIT for acquiring Python programs targeting PyTorch and custom operations
-- a multi-level IR to represent them as a trace of a reduced op-set
-- an extensible set of transformations on the trace, such as `grad`, fusions, distributed (like `ddp`, `fsdp`), functional (like `vmap`, `vjp`, `jvp`)
-- a way to dispatch operations to an extensible collection of executors
+- A JIT for acquiring Python programs targeting PyTorch and custom operations
+- A multi-level IR to represent operations as a trace of a reduced op-set
+- An extensible set of transformations on the trace, such as `grad`, fusions, distributed (like `ddp`, `fsdp`), functional (like `vmap`, `vjp`, `jvp`)
+- A way to dispatch operations to an extensible collection of executors
Thunder is written entirely in Python. Even its trace is represented as valid Python at all stages of transformation. This allows unprecedented levels of introspection and extensibility.
-Thunder doesn't generate device code. It acquires and transforms user programs so that it's possible to optimally select or generate device code using fast executors like:
+Thunder doesn't generate code for accelerators directly. It acquires and transforms user programs so that it's possible to optimally select or generate device code using fast executors like:
- [torch.compile](https://pytorch.org/get-started/pytorch-2.0/)
- [nvFuser](https://github.com/NVIDIA/Fuser)
@@ -106,7 +169,7 @@ Thunder doesn't generate device code. It acquires and transforms user programs s
Modules and functions compiled with Thunder fully interoperate with vanilla PyTorch and support PyTorch's autograd. Also, Thunder works alongside torch.compile to leverage its state-of-the-art optimizations.
-## Build the documentation
+## Documentation
Docs are currently not hosted publicly. However you can build them locally really quickly:
@@ -141,9 +204,4 @@ Thunder is very thoroughly tested, so expect this to take a while.
## License
Lightning Thunder is released under the [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) license.
-See LICENSE file for details.
-
-[![CI testing](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml)
-[![General checks](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml)
-[![Documentation Status](https://readthedocs.org/projects/lightning-thunder/badge/?version=latest)](https://lightning-thunder.readthedocs.io/en/latest/?badge=latest)
-[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Lightning-AI/lightning-thunder/main.svg?badge_token=mqheL1-cTn-280Vx4cJUdg)](https://results.pre-commit.ci/latest/github/Lightning-AI/lightning-thunder/main?badge_token=mqheL1-cTn-280Vx4cJUdg)
+See the [LICENSE](LICENSE) file for details.
diff --git a/docs/source/_static/images/lightning_thunder_lightmode_nobyline.png b/docs/source/_static/images/lightning_thunder_lightmode_nobyline.png
new file mode 100644
index 0000000000..831a23e233
Binary files /dev/null and b/docs/source/_static/images/lightning_thunder_lightmode_nobyline.png differ
diff --git a/docs/source/_static/images/normalized_training_throughput_zero2.png b/docs/source/_static/images/normalized_training_throughput_zero2.png
new file mode 100644
index 0000000000..be6e5888c3
Binary files /dev/null and b/docs/source/_static/images/normalized_training_throughput_zero2.png differ
diff --git a/docs/source/_static/images/training_throughput_single.png b/docs/source/_static/images/training_throughput_single.png
new file mode 100644
index 0000000000..6c0a7029a4
Binary files /dev/null and b/docs/source/_static/images/training_throughput_single.png differ
diff --git a/docs/source/advanced/inside_thunder.rst b/docs/source/advanced/inside_thunder.rst
index 36013d66ae..2ae68509d8 100644
--- a/docs/source/advanced/inside_thunder.rst
+++ b/docs/source/advanced/inside_thunder.rst
@@ -8,9 +8,9 @@ Bytecode interpretation
Thunder's interpreter works by:
-1. disassembling the PyTorch module or function into CPython bytecode
-2. interpreting the bytecode using an extended Python interpreter
-3. generating a sequential trace of operations on tensors and numbers
+1. Disassembling the PyTorch module or function into CPython bytecode
+2. Interpreting the bytecode using an extended Python interpreter
+3. Generating a sequential trace of operations on tensors and numbers
Representing Operations
=======================
diff --git a/docs/source/basic/overview.rst b/docs/source/basic/overview.rst
index cb272c0b9a..c456635e62 100644
--- a/docs/source/basic/overview.rst
+++ b/docs/source/basic/overview.rst
@@ -3,7 +3,7 @@ Thunder Overview
This section introduces Thunder's core concepts and architecture. For more details, see :doc:`Inside thunder <../advanced/inside_thunder>`.
-Thunder is a deep learning compiler for PyTorch, which means it translates calls to PyTorch modules into a format that is easy to transform and that executors can consume to produce fast executables. This translation must be “valid” - it must produce a simple representation focusing on tensor operations. The format we've chosen, like other deep learning compilers, is a sequence of operations called a program *trace*.
+Thunder is a deep learning compiler for PyTorch, which means it translates calls to PyTorch modules into a format that is easy to transform and that executors can consume to produce fast executables. This translation must produce a simple representation focusing on tensor operations. The format we've chosen, like other deep learning compilers, is a sequence of operations called a program *trace*.
This translation begins with::
@@ -13,7 +13,7 @@ or::
jitted_fn = thunder.jit(my_function)
-When given a module, the call to ``thunder.jit()`` returns a Thunder-optimized module that shares parameters with the original module (as demonstrated in the :doc:`Train a MLP on MNIST ` example), and when given a function it returns a jitted function.
+When given a module, the call to ``thunder.jit()`` returns a Thunder-optimized module that shares parameters with the original module (as demonstrated in the :doc:`Train a MLP on MNIST ` example), and when given a function it returns a function that when called will jit compile a path through the original function given information about the inputs.
When the jitted module or function is called::
@@ -23,22 +23,23 @@ or::
jitted_fn(*args, **kwargs)
-Thunder begins reviewing the module's or function's Python bytecode and the input. It may be surprising that Thunder considers the inputs at all, but this is actually required to produce a trace. Different inputs can produce different traces, since the operations called may different based on the properties of the input.
-The trace is generated by running the bytecode through an extensible Python interpreter implemented in Python itself, that can be extended to perform instructions in a different way compared to what standard CPython does. As such, it can be instrumented to construct a trace of operations performed on tensors or numbers, and keep track of the provenance of all objects being part of the program.
+As suggested above, Thunder begins reviewing the module's or function's Python bytecode and the input. It may be surprising that Thunder considers the inputs at all, but since control flow (and therefore the operations captured) may vary depending on the input, this is actually required to produce a trace. These traces are cached, so that if inputs of the same type, shape, etc are used again, the trace can be reused.
-If replacing CPython with Python itself sounds problematic from a performance perspective, keep in mind that the initial interpretation of a deep learning program is typically amortized during the subsequent interpretations, due to the iterative nature of deep learning programs. In other words, if the meta data of inputs (like tensor shape) doesn't change and control-flow conditions are unchanged, then there's no point in constructing a new trace, and we can rely on smart caching to just execute a trace right away.
+Traces are generated by running the bytecode through a custom Python interpreter, which is itself implemented in Python. This interpreter has been extended to perform instructions in a different way compared to what standard CPython does. In particular, it constructs a trace of operations performed on tensors or numbers, and keeps track of the provenance of all objects in the program, whether they originated from inside the interpreter or outside.
-Traces don't typically deal with PyTorch tensors, but with *proxies* that only have metadata like shape, device, dtype, and whether the tensor requires grad or not. As such, during interpretation for trace generation, the execution of the program doesn't perform any computation on accelerators, but it records the operators along one path of the traceable function into the trace.
+Much like other machine learning frameworks, Traces don't typically deal directly with PyTorch tensors, but with *proxies* that only have metadata like shape, device, dtype, and whether the tensor requires grad or not. As such, during interpretation for trace generation, the execution of the program doesn't perform any computation on accelerators. Instead, it records the operators along one path of the traceable function.
-Traces can be transformed (like for backward) and optimized (like by replacing calls to PyTorch operations with calls to faster executors), and the final result of this process is an *execution trace*. Thunder executes the original call by converting the execution trace into a Python function and calling that function with the actual inputs. For details about this optimization process see the :doc:`thunder step by step ` section.
+If replacing CPython with an interpreter written in Python sounds problematic from a performance perspective, you would be largely correct. We haven't yet put any time into optimizing it, and we think it consumes roughly 400x as much CPU time as CPython. However, the function only needs to be jitted once per equivalence class of inputs, and CPU is not a bottleneck in most machine learning pipelines. As long as the metadata of the inputs (such as a tensor's shape) and control flow conditions are not changed, we can rely on smart caching to immediately execute an optimized trace. The end result is a faster total execution time.
+
+Traces can be transformed (like for ``backward()``) and optimized (like by replacing calls to eager PyTorch operations with calls to faster executors), and the final result of this process is an *execution trace*. Thunder executes the original call by converting the execution trace into a Python function and calling that function with the actual inputs. For details about this optimization process, see the :doc:`thunder step by step ` section.
To recap, the complete translation process is:
-- For PyTorch modules, a Thunder-optimized module is created from the original module
-- For PyTorch functions, compilation produces a compiled function
-- When the module or function is called, the trace is generated, swapping some inputs with “proxies”
-- The trace is transformed and optimized to produce an execution trace
-- The execution trace is converted into a Python function and called
+- For PyTorch modules, a Thunder-optimized module is created from the original module.
+- For PyTorch functions, compilation produces a compiled function.
+- When the module or function is called, the trace is generated, swapping some inputs with “proxies”.
+- The trace is transformed and optimized to produce an execution trace.
+- The execution trace is converted into a Python function and called.
-As mentioned above, this translation process is often slow - it takes tens of seconds for nanoGPT's (https://github.com/karpathy/nanoGPT) largest configuration - so Thunder's performance model expects relatively few of these translations and then a lot of uses of the result. This corresponds with many training and inference patterns, where the same program is executed many times.
+As mentioned, this translation process is often slow - it takes tens of seconds for nanoGPT's (https://github.com/karpathy/nanoGPT) largest configuration - so Thunder's performance model expects relatively few of these translations and then a lot of uses of the result. This corresponds with many training and inference patterns, where the same program is executed many times.
diff --git a/docs/source/basic/sharp_edges.rst b/docs/source/basic/sharp_edges.rst
index bf2b0abf16..d62590316b 100644
--- a/docs/source/basic/sharp_edges.rst
+++ b/docs/source/basic/sharp_edges.rst
@@ -10,12 +10,6 @@ Inplace operations
Inplace PyTorch operations like `t.add_(1.0)` are not supported in Thunder yet. Support for inplace operations is coming soon.
-Complex control flow
---------------------
-
-Control flow is supported in Thunder, but certain constructs might still be unsupported.
-
-In particular, attributes need to be resolved at tracing time for control flow to work. Data-dependent control flow, that is, when a condition depends on the value of tensors rather than its meta-data like shape or type, is currently not supported.
Tensor subclasses
-----------------
@@ -24,13 +18,14 @@ Thunder currently supports Python data types and PyTorch tensors as inputs of fu
Subclasses of these types, e.g. lazy tensors, nested tensors, or sparse tensors are not supported today.
+
Tracing Python builtins, standard library operations and functions that call other languages
--------------------------------------------------------------------------------------------
Calling a Python builtin, standard library operation, or a function that calls into another language is safe to trace, so long as the following rules are observed:
-1. The function must not have side effects. For example, calling ``print()`` will execute the ``print()`` function while tracing, but since it's not a Thunder operation it will not appear in a trace, and so future cached executions will not execute the ``print()`` statement.
-2. The function must not manipulate tensor metadata or data. Since the operation won't appear in a trace, these manipulations won't be repeated by Thunder, and may even cause a crash while tracing.
+1. The function should not have side effects. For example, calling ``print()`` will execute the ``print()`` function while tracing, but since it's not a Thunder operation it will not appear in a trace, and so future cached executions will not execute the ``print()`` statement.
+2. The function must not manipulate tensor data or metadata. Since the operation won't appear in a trace, these manipulations won't be repeated by Thunder, and may even cause a crash while tracing. To implement such operations, see :doc:`Adding Custom Operators <../notebooks/adding_custom_operator>`
3. The function must not produce different results across invocations. Again, since the operation won't appear in traces, Thunder cannot replicate an operation that produces different results when it's invoked, like ``random.random()`` will.
..
diff --git a/docs/source/conf.py b/docs/source/conf.py
index dcf2414b54..052fa2437c 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -49,7 +49,10 @@
github_user = "Lightning-AI"
github_repo = project
-linkcheck_ignore = [rf"https://github.com/Lightning-AI/lightning-thunder(/.*|\.git)"]
+linkcheck_ignore = [
+ rf"https://github.com/Lightning-AI/lightning-thunder(/.*|\.git)",
+ rf"https://github.com/Lightning-AI/.*/blob/.*#.*", # github anchors are tricky
+]
# -- Project documents -------------------------------------------------------
diff --git a/docs/source/fundamentals/installation.rst b/docs/source/fundamentals/installation.rst
index 7ee759ea82..8d41a24047 100644
--- a/docs/source/fundamentals/installation.rst
+++ b/docs/source/fundamentals/installation.rst
@@ -56,11 +56,11 @@ Thunder can easily integrate OpenAI Triton kernels. You can install Triton using
Install Thunder
===============
-You can now install Thunder
+You can now install Thunder::
pip install git+https://github.com/Lightning-AI/lightning-thunder.git
-Alternatively you can clone the Thunder repository and install locally
+Alternatively you can clone the Thunder repository and install locally::
git clone https://github.com/Lightning-AI/lightning-thunder.git
cd lightning-thunder
diff --git a/docs/source/intermediate/additional_executors.rst b/docs/source/intermediate/additional_executors.rst
index 76bb270bf0..911d7f16e9 100644
--- a/docs/source/intermediate/additional_executors.rst
+++ b/docs/source/intermediate/additional_executors.rst
@@ -10,11 +10,9 @@ Triton CrossEntropy Executor
The Triton CrossEntropy executor can execute ``torch.cross_entropy()`` using an optimized kernel written in OpenAI Triton (https://github.com/openai/triton). It can be used like in the following example::
+ import torch
import thunder
- from thunder.executors import nvfuserex, torchex
- from thunder.executors.triton_crossentropy import deregister_triton_entropyex, register_triton_entropyex
-
- register_triton_entropyex(add_to_default_executors=False)
+ from thunder.executors.triton_crossentropy import triton_ex as triton_cross_entropy_ex
def xentropy(logits, labels, weight, reduction, ignore_index):
return thunder.torch.cross_entropy(
@@ -23,7 +21,7 @@ The Triton CrossEntropy executor can execute ``torch.cross_entropy()`` using an
jitted_xentropy = thunder.jit(
xentropy,
- executors_list=['triton_crossentropy', nvfuserex, torchex]
+ executors=[triton_cross_entropy_ex,]
)
device = 'cuda'
@@ -41,43 +39,42 @@ The Triton CrossEntropy executor can execute ``torch.cross_entropy()`` using an
This prints::
- # Constructed by Delete Last Used
+ # Constructed by Delete Last Used (took 0 milliseconds)
import torch
+ from thunder.executors.torchex import no_autocast
+
@torch.no_grad()
- def xentropy(logits, labels, weight, reduction, ignore_index):
+ @no_autocast()
+ def computation(logits, labels, weight):
# logits: "cuda:0 f32[2048, 50257]"
# labels: "cuda:0 i64[2048]"
# weight: "cuda:0 f32[50257]"
- # "sum"
- # ignore_index: "int 10106"
- t22 = triton_cross_entropy(logits, labels, weight, None, ignore_index, None, "sum", 0.0) # t22: "cuda:0 f32[]"
- del [logits, labels, weight, ignore_index]
- return t22
+ t23 = triton_crossentropy(logits, labels, weight, None, 45279, None, 'sum', 0.0) # t23: "cuda:0 f32[]"
+ del logits, labels, weight
+ return t23
-As shown in the above trace, ``triton_cross_entropy()`` is the one running the operation.
+As shown in the above trace, ``triton_crossentropy()`` is the one running the operation.
Apex CrossEntropy Executor
==========================
The Apex CrossEntropy executor can execute ``torch.cross_entropy()`` through an optimized kernel, like this::
+ import torch
import thunder
- from thunder.executors import nvfuserex, torchex
- from thunder.executors.apex_entropyex import deregister_apex_entropyex, register_apex_entropyex
-
- register_apex_entropyex(add_to_default_executors=False)
+ from thunder.executors.apex_entropyex import apex_ex
def xentropy(logits, labels):
return thunder.torch.cross_entropy(
logits, labels, reduction='mean', ignore_index=-1
)
- jitted_xentropy = thunder.jit(xentropy, executors_list=['apex_xentropy', nvfuserex, torchex])
+ jitted_xentropy = thunder.jit(xentropy, executors=[apex_ex,])
device = 'cuda'
dtype = torch.float32
- logits = torch.randn([2048, 50257], device=device, dtype=thunder.torch.to_torch_dtype(dtype))
+ logits = torch.randn([2048, 50257], device=device, dtype=dtype)
labels = torch.randint(0, 50257, [2048], device=device)
jitted_xentropy(logits, labels)
@@ -86,14 +83,17 @@ The Apex CrossEntropy executor can execute ``torch.cross_entropy()`` through an
This prints::
- # Constructed by Delete Last Used
+ # Constructed by Delete Last Used (took 0 milliseconds)
import torch
+ from thunder.executors.torchex import no_autocast
+
@torch.no_grad()
- def xentropy(logits, labels):
+ @no_autocast()
+ def computation(logits, labels):
# logits: "cuda:0 f32[2048, 50257]"
# labels: "cuda:0 i64[2048]"
- t18 = apex_cross_entropy(logits, labels, None, None, -1, None, "mean", 0.0) # t18: "cuda:0 f32[]"
- del [logits, labels]
+ (t18, _) = apex_cross_entropy(logits, labels, 'mean', 0.0)
+ del logits, labels
return t18
showing that Apex is running the operation.
diff --git a/examples/lit-gpt/test_parametrized.py b/examples/lit-gpt/test_parametrized.py
index bca55173fa..5e658b6447 100644
--- a/examples/lit-gpt/test_parametrized.py
+++ b/examples/lit-gpt/test_parametrized.py
@@ -7,12 +7,13 @@
MID_BENCHMARK_OUT - use this env variable to control whether you want to see the combined results
between each test.
BENCHMARK_OUT_FORMAT - use this env variable to control the format in which the results are presented.
- Uses 'xlsx' by default. More format support to come soon.
+ Uses 'xlsx' by default. Supported: 'none', 'print', 'xlsx'.
'''
import torch
from absl.testing import parameterized
from absl.testing import absltest
+from collections import defaultdict
import os
import subprocess
import json
@@ -48,6 +49,9 @@ def add_to_dataframe(self):
self.dataframe_data.append(self.perf_metrics_dict)
def complete_dataframe(self, is_teardown):
+ if not self.dataframe_data:
+ # The benchmark probably failed
+ return
#Called when tearing down the parametrized test
#This generates a summarized dataframe for each perf metric and saves as a xlsx file
df = pd.DataFrame(self.dataframe_data)
@@ -59,7 +63,7 @@ def complete_dataframe(self, is_teardown):
self.tokens_per_sec_per_gpu_df = df.pivot_table(index=index_list, columns='compiler', values='tokens_per_sec_per_gpu', aggfunc='first').reset_index()
self.memory_used_GB_df = df.pivot_table(index=index_list, columns='compiler', values='memory_used_GB', aggfunc='first').reset_index()
- if self.output_format not in ('none', 'print'):
+ if self.output_format == "xlsx":
output_ext = {'xlsx': '.xlsx', }[self.output_format]
if not is_teardown:
filename = 'examples/lit-gpt/mid_output_parameterized_results' + str(output_ext)
@@ -84,7 +88,6 @@ def complete_dataframe(self, is_teardown):
print(self.memory_used_GB_df)
def run_benchmark(self, kwargs):
- # benchmark_file = 'thunder/benchmarks/benchmark_litgpt.py'
command_list = []
for key, val in kwargs.items():
command_list.append("--" + str(key) + "=" + str(val))
@@ -98,32 +101,26 @@ def run_benchmark(self, kwargs):
print(f'Running {" ".join(subprocess_cmd)!r}')
proc_output = subprocess.run(subprocess_cmd, capture_output=True, text=True)
+
+ self.perf_metrics_dict = {}
+ if os.path.exists(self.json_file_path):
+ with open(self.json_file_path, 'r') as file:
+ self.perf_metrics_dict = json.load(file)
+ # Cleanup after the benchmark finishes. It might have failed before creating this
+ os.remove(self.json_file_path)
+
if proc_output.returncode:
- print(proc_output.stdout)
- print(proc_output.stderr)
- proc_output.check_returncode()
-
- with open(self.json_file_path, 'r') as file:
- self.perf_metrics_dict = json.load(file)
- os.remove(self.json_file_path) #cleanup after test finishes
-
- if self.perf_metrics_dict['average_iter_time'] is None:
- if 'CUDA out of memory' in proc_output.stdout:
- self.perf_metrics_dict['average_iter_time'] = 'OOM'
- self.perf_metrics_dict['model_flops'] = 'OOM'
- self.perf_metrics_dict['model_flop_per_sec'] = 'OOM'
- self.perf_metrics_dict['tokens_per_sec'] = 'OOM'
- self.perf_metrics_dict['tokens_per_sec_per_gpu'] = 'OOM'
- self.perf_metrics_dict['memory_used_GB'] = 'OOM'
+ if 'CUDA out of memory' in proc_output.stdout or "CUDA error: out of memory" in proc_output.stderr:
+ defaultdict_oom = defaultdict(lambda: "OOM")
+ defaultdict_oom.update(self.perf_metrics_dict)
+ self.perf_metrics_dict = defaultdict_oom
pass_str = "TestCase did not finish reporting metrics due to CUDA out of memory error. Reporting OOM and triggering test success."
return True, pass_str
- else:
- print(proc_output.stdout)
- print(proc_output.stderr)
- fail_str = "Testcase did not finish reporting metrics due to an unknown error. Triggering test failure."
- return False, fail_str
- else:
- return True, "Test passed successfully."
+ print(proc_output.stdout)
+ print(proc_output.stderr)
+ fail_str = "TestCase did not finish reporting metrics due to an unknown error. Triggering test failure."
+ return False, fail_str
+ return True, "Test passed successfully."
class Test(parameterized.TestCase):
diff --git a/examples/llama2.c/README.md b/examples/llama2.c/README.md
index 5acb8f0742..4a999da840 100644
--- a/examples/llama2.c/README.md
+++ b/examples/llama2.c/README.md
@@ -28,9 +28,9 @@ The code is configured to run with Thunder by default.
Results with 1 GPU:
-- ~339 ms/iter (torch.compile 'inductor')
-- ~347 ms/iter (thunder nvfuser)
-- ~431 ms/iter (eager)
+- ~215 ms/iter (torch.compile 'inductor')
+- ~239 ms/iter (thunder nvfuser)
+- ~339 ms/iter (eager)
CUDAGraphs are not used as the results were worse with them.
@@ -46,15 +46,14 @@ nanoGPT doesn't implement KV caching so this is expectedly slow. Please checkout
## Setup
```text
-Python version: 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] (64-bit runtime)
+Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Is debug build: False
-CUDA used to build PyTorch: 12.1
-CUDA runtime version: 12.1.105
+CUDA used to build PyTorch: 12.4
+CUDA runtime version: 12.4.99
GPU 0: NVIDIA A100-SXM4-40GB
-Nvidia driver version: 525.125.06
+Nvidia driver version: 550.54.14
-pytorch-triton @ https://download.pytorch.org/whl/nightly/pytorch_triton-3.0.0%2B901819d2b6-cp310-cp310-linux_x86_64.whl
-torch @ https://download.pytorch.org/whl/nightly/cu121/torch-2.3.0.dev20240130%2Bcu121-cp310-cp310-linux_x86_64.whl
-lightning-thunder==8b107c6fe531c94c6705dbf39700863685ba5b65
-nvfuser_cu121==0.1.5.dev20240131
+triton == 3.0.0
+torch == 2.4.0a0+git685ace3
+nvfuser @ 0.2.0+git70101da
```
diff --git a/examples/llama2.c/sample.py b/examples/llama2.c/sample.py
index b8ccacfa48..094b6c3f74 100644
--- a/examples/llama2.c/sample.py
+++ b/examples/llama2.c/sample.py
@@ -55,7 +55,7 @@
from thunder.executors.sdpaex import sdpa_ex
executors = [sdpa_ex, thunder.nvfuser_executor, thunder.pytorch_executor]
- cmodel = thunder.jit(model, disable_torch_autograd_support=True, executors_list=executors)
+ cmodel = thunder.jit(model, disable_torch_autograd_support=True, executors=executors)
# the generate implementation is not compile friendly, so bind the compiled model to the generate implementation
generate = partial(Transformer.generate, cmodel)
# workaround for "Foward nn.Module attributes through the ThunderOptimizedModule"
diff --git a/examples/llama2.c/train.py b/examples/llama2.c/train.py
index 58d88d4729..206a4e065d 100644
--- a/examples/llama2.c/train.py
+++ b/examples/llama2.c/train.py
@@ -70,8 +70,8 @@
warmup_iters = 1000 # how many steps to warm up for
# system
device = "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
-# dtype = "bfloat16" # float32|bfloat16|float16
-compile = "thunder" # eager|torch|thunder
+dtype = "bfloat16" # float32|bfloat16|float16
+compile = "thunder" # thunder|torch|eager
# -----------------------------------------------------------------------------
config_keys = [
k
@@ -122,8 +122,15 @@
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast
# note: float16 data type will automatically use a GradScaler
-# ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]
-ctx = nullcontext() # torch.amp.autocast(device_type=device_type, dtype=ptdtype)
+ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]
+ctx = (
+ nullcontext()
+ if device_type == "cpu"
+ else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
+)
+# Disable other than FlashAttention backends for SDPA
+torch.backends.cuda.enable_math_sdp(False)
+torch.backends.cuda.enable_mem_efficient_sdp(False)
# task-specific setup
iter_batches = partial(
@@ -179,10 +186,11 @@
model.load_state_dict(state_dict)
iter_num = checkpoint["iter_num"]
best_val_loss = checkpoint["best_val_loss"]
+
model.to(device)
# initialize a GradScaler. If enabled=False scaler is a no-op
-scaler = torch.cuda.amp.GradScaler(enabled=(False)) # dtype == "float16"))
+scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16"))
# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
@@ -190,19 +198,19 @@
optimizer.load_state_dict(checkpoint["optimizer"])
checkpoint = None # free up memory
-raw_model = eval_model = train_model = model
+raw_model = model
# wrap model into DDP container
if ddp:
if compile == "thunder":
from thunder.distributed import ddp
- train_model = ddp(train_model)
+ model = ddp(model)
else:
# Ignore the `freqs_cis` buffer so that DDP does not broadcast it at
# construction time since NCCL does not support `ComplexFloat`
- train_model._ddp_params_and_buffers_to_ignore = {"freqs_cis"}
- train_model = DDP(train_model, device_ids=[ddp_local_rank])
+ model._ddp_params_and_buffers_to_ignore = {"freqs_cis"}
+ model = DDP(model, device_ids=[ddp_local_rank])
# compile the model
if compile == "thunder":
@@ -212,31 +220,29 @@
from thunder.executors.sdpaex import sdpa_ex
executors = [sdpa_ex, thunder.nvfuser_executor, thunder.pytorch_executor]
- eval_model = thunder.compile(eval_model.eval(), disable_torch_autograd_support=True, executors_list=executors)
- train_model = thunder.compile(train_model.train(), executors_list=executors)
+ model = thunder.jit(model, executors=executors)
elif compile == "torch":
print("compiling the model with torch... (takes a ~minute)")
- eval_model = torch.compile(eval_model)
- train_model = torch.compile(train_model)
+ model = torch.compile(model)
# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss():
out = {}
if compile != "thunder":
- eval_model.eval()
+ model.eval()
for split in ["train", "val"]:
batch_iter = iter_batches(split=split)
losses = torch.zeros(eval_iters) # keep on CPU
for k in range(eval_iters):
X, Y = next(batch_iter)
with ctx:
- logits = eval_model(X, Y)
+ logits = model(X, Y)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=-1)
losses[k] = loss.item()
out[split] = losses.mean()
if compile != "thunder":
- train_model.train()
+ model.train()
return out
# learning rate decay scheduler (cosine with warmup)
@@ -313,9 +319,9 @@ def get_lr(it):
# the official way to do this is with model.no_sync() context manager, but
# this forces us to repeat code.
# looking at the source of that context manager, it just toggles this variable
- train_model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
+ model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
with ctx:
- logits = train_model(X, Y)
+ logits = model(X, Y)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=-1)
loss = loss / gradient_accumulation_steps
# immediately async prefetch next batch while model is doing the forward pass on the GPU
@@ -325,7 +331,7 @@ def get_lr(it):
# clip the gradient
if grad_clip != 0.0:
scaler.unscale_(optimizer)
- torch.nn.utils.clip_grad_norm_(train_model.parameters(), grad_clip)
+ torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
# step the optimizer and scaler if training in fp16
scaler.step(optimizer)
scaler.update()
diff --git a/notebooks/zero_to_thunder.ipynb b/notebooks/zero_to_thunder.ipynb
index 68f61a47a0..a1a888cc72 100644
--- a/notebooks/zero_to_thunder.ipynb
+++ b/notebooks/zero_to_thunder.ipynb
@@ -3,7 +3,11 @@
{
"cell_type": "markdown",
"id": "1638964c",
- "metadata": {},
+ "metadata": {
+ "slideshow": {
+ "slide_type": "slide"
+ }
+ },
"source": [
"# Zero to Thunder\n",
"\n",
@@ -21,16 +25,18 @@
"source": [
"import sys\n",
"sys.path.insert(0, '..')\n",
- "import inspect\n",
- "\n",
"\n",
- "import torch, thunder\n"
+ "import torch, thunder"
]
},
{
"cell_type": "markdown",
"id": "54f87aba",
- "metadata": {},
+ "metadata": {
+ "slideshow": {
+ "slide_type": "slide"
+ }
+ },
"source": [
"## Compiling a first module with Thunder\n",
"\n",
@@ -40,7 +46,7 @@
{
"cell_type": "code",
"execution_count": 2,
- "id": "d6ca6328",
+ "id": "892be718",
"metadata": {},
"outputs": [
{
@@ -62,26 +68,26 @@
" self.fc_1 = torch.nn.Linear(n_embd, intermediate_size, bias=False)\n",
" self.fc_2 = torch.nn.Linear(n_embd, intermediate_size, bias=False)\n",
" self.proj = torch.nn.Linear(intermediate_size, n_embd, bias=False)\n",
- "\n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" x_fc_1 = self.fc_1(x)\n",
" x_fc_2 = self.fc_2(x)\n",
" x = torch.nn.functional.silu(x_fc_1) * x_fc_2\n",
" return self.proj(x)\n",
- "\n",
- "\n",
"with torch.device(\"cuda\"):\n",
" m = LLaMAMLP(4096, 11008)\n",
"for p in m.parameters():\n",
" p.requires_grad_(False)\n",
- "\n",
- "print(m)"
+ "print(m)\n"
]
},
{
"cell_type": "markdown",
"id": "702ea054",
- "metadata": {},
+ "metadata": {
+ "slideshow": {
+ "slide_type": "slide"
+ }
+ },
"source": [
"Now we can apply Thunder. This uses the most important function of Thunder, `thunder.jit`, which can be used to compile a `torch.nn.Module` or a function. It will wrap our MLP in a `ThunderModule`"
]
@@ -125,8 +131,12 @@
},
{
"cell_type": "markdown",
- "id": "59db20f6",
- "metadata": {},
+ "id": "47d24f2d-0e89-4fe8-8154-9b50f2633e1b",
+ "metadata": {
+ "slideshow": {
+ "slide_type": "slide"
+ }
+ },
"source": [
"Our Thunder module computes (up to numerical accuracy) the same thing as our original model and for a small model like this, it also has approximately the same performance."
]
@@ -135,15 +145,19 @@
"cell_type": "code",
"execution_count": 5,
"id": "7f4de1b3",
- "metadata": {},
+ "metadata": {
+ "slideshow": {
+ "slide_type": "-"
+ }
+ },
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"deviation: 1.4901161193847656e-07\n",
- "58.2 ms ± 306 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
- "58.7 ms ± 50.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
+ "61.3 ms ± 106 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
+ "62.1 ms ± 89.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
@@ -157,11 +171,14 @@
},
{
"cell_type": "markdown",
- "id": "8835543e",
- "metadata": {},
+ "id": "7996acc7-de20-4aa5-80f0-1ab6042e2650",
+ "metadata": {
+ "slideshow": {
+ "slide_type": "slide"
+ }
+ },
"source": [
- "So what has changed?\n",
- "Quite a bit!\n",
+ "So what has changed? Quite a bit!\n",
"\n",
"When we call the Thunder module, it do the computation in a single function without control flow. And what's more, it applies optimizations, such as creating fusions for NVFuser to execute. We can see all this by showing the last computation trace:"
]
@@ -170,7 +187,9 @@
"cell_type": "code",
"execution_count": 6,
"id": "a6f4b77c",
- "metadata": {},
+ "metadata": {
+ "scrolled": true
+ },
"outputs": [
{
"data": {
@@ -221,8 +240,12 @@
},
{
"cell_type": "markdown",
- "id": "a0071924",
- "metadata": {},
+ "id": "2ef89186-70cd-4737-9695-ed282da2a56c",
+ "metadata": {
+ "slideshow": {
+ "slide_type": "notes"
+ }
+ },
"source": [
"For more detail of what is going on in this trace:\n",
"- Thunder has transformed the computation (more precisely, `m.__call__`) into a single function which has all the MLP parameters as arguments.\n",
@@ -237,11 +260,17 @@
{
"cell_type": "markdown",
"id": "7749aed1",
- "metadata": {},
+ "metadata": {
+ "slideshow": {
+ "slide_type": "slide"
+ }
+ },
"source": [
"## Compiling a more complex model\n",
"\n",
- "Obviously, we aim for larger models, so we can do the same with the entire LLama 2 (well, we have a smaller momdel here to be mild to our CI, but if you have a large GPU, just drop reducing the number of layers):"
+ "Obviously, we aim for larger models, so we can do the same with the entire LLama 2 (well, we have a smaller momdel here to be mild to our CI, but if you have a large GPU, just drop reducing the number of layers):\n",
+ "\n",
+ "**NOTE**: For running the cells below, we require `litgpt` which can be installed with `pip install 'litgpt[all] @ git+https://github.com/Lightning-AI/litgpt'`. See [here](https://github.com/Lightning-AI/litgpt) to learn more about litgpt."
]
},
{
@@ -258,7 +287,7 @@
" (transformer): ModuleDict(\n",
" (wte): Embedding(32000, 4096)\n",
" (h): ModuleList(\n",
- " (0-3): 4 x Block(\n",
+ " (0-15): 16 x Block(\n",
" (norm_1): RMSNorm()\n",
" (attn): CausalSelfAttention(\n",
" (attn): Linear(in_features=4096, out_features=12288, bias=False)\n",
@@ -286,7 +315,8 @@
"from lit_gpt import GPT\n",
"from thunder.tests.lit_gpt_model import Config\n",
"cfg = Config.from_name('Llama-2-7b-hf')\n",
- "cfg.n_layer = 4 # fewer layers\n",
+ "cfg.n_layer = 16 # fewer layers\n",
+ "torch.set_default_dtype(torch.bfloat16)\n",
"with torch.device('cuda'):\n",
" m = GPT(cfg)\n",
"m\n"
@@ -310,7 +340,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "deviation: 1.8477439880371094e-06\n"
+ "deviation: 0.03125\n"
]
}
],
@@ -327,22 +357,37 @@
},
{
"cell_type": "markdown",
- "id": "2f681093",
- "metadata": {},
+ "id": "9947e8df-cd2d-447d-90b9-ee08bb5a9fb2",
+ "metadata": {
+ "slideshow": {
+ "slide_type": "slide"
+ }
+ },
"source": [
- "Just like before, we can see the program it ran:"
+ "One thing to keep in mind here is that for bf16, the numerical accuracy impact of rearranging operations can be quite pronounced.\n",
+ "\n",
+ "Just like before, we can see the program it ran, it is a lot longer, though."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "ac7e8bc9",
- "metadata": {},
+ "metadata": {
+ "scrolled": true
+ },
"outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
{
"data": {
"text/plain": [
- "# Constructed by Delete Last Used (took 1 milliseconds)\n",
+ "# Constructed by Delete Last Used (took 10 milliseconds)\n",
"import torch\n",
"from torch import Tensor\n",
"import torch.nn.functional\n",
@@ -386,626 +431,2728 @@
" t31, \\\n",
" t32, \\\n",
" t33, \\\n",
+ " t34, \\\n",
+ " t35, \\\n",
+ " t36, \\\n",
+ " t37, \\\n",
+ " t38, \\\n",
+ " t39, \\\n",
+ " t40, \\\n",
+ " t41, \\\n",
+ " t42, \\\n",
+ " t43, \\\n",
+ " t44, \\\n",
+ " t45, \\\n",
+ " t46, \\\n",
+ " t47, \\\n",
+ " t48, \\\n",
+ " t49, \\\n",
+ " t50, \\\n",
+ " t51, \\\n",
+ " t52, \\\n",
+ " t53, \\\n",
+ " t54, \\\n",
+ " t55, \\\n",
+ " t56, \\\n",
+ " t57, \\\n",
+ " t58, \\\n",
+ " t59, \\\n",
+ " t60, \\\n",
+ " t61, \\\n",
+ " t62, \\\n",
+ " t63, \\\n",
+ " t64, \\\n",
+ " t65, \\\n",
+ " t66, \\\n",
+ " t67, \\\n",
+ " t68, \\\n",
+ " t69, \\\n",
+ " t70, \\\n",
+ " t71, \\\n",
+ " t72, \\\n",
+ " t73, \\\n",
+ " t74, \\\n",
+ " t75, \\\n",
+ " t76, \\\n",
+ " t77, \\\n",
+ " t78, \\\n",
+ " t79, \\\n",
+ " t80, \\\n",
+ " t81, \\\n",
+ " t82, \\\n",
+ " t83, \\\n",
+ " t84, \\\n",
+ " t85, \\\n",
+ " t86, \\\n",
+ " t87, \\\n",
+ " t88, \\\n",
+ " t89, \\\n",
+ " t90, \\\n",
+ " t91, \\\n",
+ " t92, \\\n",
+ " t93, \\\n",
+ " t94, \\\n",
+ " t95, \\\n",
+ " t96, \\\n",
+ " t97, \\\n",
+ " t98, \\\n",
+ " t99, \\\n",
+ " t100, \\\n",
+ " t101, \\\n",
+ " t102, \\\n",
+ " t103, \\\n",
+ " t104, \\\n",
+ " t105, \\\n",
+ " t106, \\\n",
+ " t107, \\\n",
+ " t108, \\\n",
+ " t109, \\\n",
+ " t110, \\\n",
+ " t111, \\\n",
+ " t112, \\\n",
+ " t113, \\\n",
+ " t114, \\\n",
+ " t115, \\\n",
+ " t116, \\\n",
+ " t117, \\\n",
" = args\n",
" del args\n",
- " t38 = torch.nn.functional.embedding(t0, t33, None, None, 2.0, False, False) # t38: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t38 = ltorch.embedding(t0, t33, None, None, 2.0, False, False) # t38: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t334 = ltorch.reshape(t0, [512]) # t334: \"cuda:0 i64[512]\"\n",
- " # t334 = prims.reshape(t0, (512,)) # t334: \"cuda:0 i64[512]\"\n",
- " # t335 = prims.take(t33, t334, 0) # t335: \"cuda:0 f32[512, 4096]\"\n",
- " # t38 = ltorch.reshape(t335, [1, 512, 4096]) # t38: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t38 = prims.reshape(t335, (1, 512, 4096)) # t38: \"cuda:0 f32[1, 512, 4096]\"\n",
- " t34 = torch_slice_prim_impl(t1, [0, 0], [512, 128], [1, 1]) # t34: \"cuda:0 f32[512, 128]\"\n",
- " t35 = torch_slice_prim_impl(t2, [0, 0], [512, 128], [1, 1]) # t35: \"cuda:0 f32[512, 128]\"\n",
- " t374 = torch.unsqueeze(t17, 0) # t374: \"cuda:0 f32[1, 4096]\"\n",
- " # t374 = ltorch.unsqueeze(t17, 0) # t374: \"cuda:0 f32[1, 4096]\"\n",
- " # t374 = prims.broadcast_in_dim(t17, [1, 4096], [1]) # t374: \"cuda:0 f32[1, 4096]\"\n",
- " t375 = torch.unsqueeze(t374, 1) # t375: \"cuda:0 f32[1, 1, 4096]\"\n",
- " # t375 = ltorch.unsqueeze(t374, 1) # t375: \"cuda:0 f32[1, 1, 4096]\"\n",
- " # t375 = prims.broadcast_in_dim(t374, [1, 1, 4096], [0, 2]) # t375: \"cuda:0 f32[1, 1, 4096]\"\n",
- " del t374\n",
- " t47 = Tensor.expand(t375, (1, 512, 4096)) # t47: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t47 = ltorch.expand(t375, (1, 512, 4096)) # t47: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t47 = prims.broadcast_in_dim(t375, (1, 512, 4096), (0, 1, 2)) # t47: \"cuda:0 f32[1, 512, 4096]\"\n",
- " del t375\n",
- " t475 = torch.unsqueeze(t24, 0) # t475: \"cuda:0 f32[1, 4096]\"\n",
- " # t475 = ltorch.unsqueeze(t24, 0) # t475: \"cuda:0 f32[1, 4096]\"\n",
- " # t475 = prims.broadcast_in_dim(t24, [1, 4096], [1]) # t475: \"cuda:0 f32[1, 4096]\"\n",
- " t476 = torch.unsqueeze(t475, 1) # t476: \"cuda:0 f32[1, 1, 4096]\"\n",
- " # t476 = ltorch.unsqueeze(t475, 1) # t476: \"cuda:0 f32[1, 1, 4096]\"\n",
- " # t476 = prims.broadcast_in_dim(t475, [1, 1, 4096], [0, 2]) # t476: \"cuda:0 f32[1, 1, 4096]\"\n",
- " del t475\n",
- " t311 = Tensor.expand(t476, (1, 512, 4096)) # t311: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t311 = ltorch.expand(t476, (1, 512, 4096)) # t311: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t311 = prims.broadcast_in_dim(t476, (1, 512, 4096), (0, 1, 2)) # t311: \"cuda:0 f32[1, 512, 4096]\"\n",
- " del t476\n",
- " t478 = torch.unsqueeze(t16, 0) # t478: \"cuda:0 f32[1, 4096]\"\n",
- " # t478 = ltorch.unsqueeze(t16, 0) # t478: \"cuda:0 f32[1, 4096]\"\n",
- " # t478 = prims.broadcast_in_dim(t16, [1, 4096], [1]) # t478: \"cuda:0 f32[1, 4096]\"\n",
- " t479 = torch.unsqueeze(t478, 1) # t479: \"cuda:0 f32[1, 1, 4096]\"\n",
- " # t479 = ltorch.unsqueeze(t478, 1) # t479: \"cuda:0 f32[1, 1, 4096]\"\n",
- " # t479 = prims.broadcast_in_dim(t478, [1, 1, 4096], [0, 2]) # t479: \"cuda:0 f32[1, 1, 4096]\"\n",
- " del t478\n",
- " t331 = Tensor.expand(t479, (1, 512, 4096)) # t331: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t331 = ltorch.expand(t479, (1, 512, 4096)) # t331: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t331 = prims.broadcast_in_dim(t479, (1, 512, 4096), (0, 1, 2)) # t331: \"cuda:0 f32[1, 512, 4096]\"\n",
- " del t479\n",
- " t403 = torch.unsqueeze(t21, 0) # t403: \"cuda:0 f32[1, 4096]\"\n",
- " # t403 = ltorch.unsqueeze(t21, 0) # t403: \"cuda:0 f32[1, 4096]\"\n",
- " # t403 = prims.broadcast_in_dim(t21, [1, 4096], [1]) # t403: \"cuda:0 f32[1, 4096]\"\n",
- " t404 = torch.unsqueeze(t403, 1) # t404: \"cuda:0 f32[1, 1, 4096]\"\n",
- " # t404 = ltorch.unsqueeze(t403, 1) # t404: \"cuda:0 f32[1, 1, 4096]\"\n",
- " # t404 = prims.broadcast_in_dim(t403, [1, 1, 4096], [0, 2]) # t404: \"cuda:0 f32[1, 1, 4096]\"\n",
- " del t403\n",
- " t98 = Tensor.expand(t404, (1, 512, 4096)) # t98: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t98 = ltorch.expand(t404, (1, 512, 4096)) # t98: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t98 = prims.broadcast_in_dim(t404, (1, 512, 4096), (0, 1, 2)) # t98: \"cuda:0 f32[1, 512, 4096]\"\n",
- " del t404\n",
- " t406 = torch.unsqueeze(t18, 0) # t406: \"cuda:0 f32[1, 4096]\"\n",
- " # t406 = ltorch.unsqueeze(t18, 0) # t406: \"cuda:0 f32[1, 4096]\"\n",
- " # t406 = prims.broadcast_in_dim(t18, [1, 4096], [1]) # t406: \"cuda:0 f32[1, 4096]\"\n",
- " t407 = torch.unsqueeze(t406, 1) # t407: \"cuda:0 f32[1, 1, 4096]\"\n",
- " # t407 = ltorch.unsqueeze(t406, 1) # t407: \"cuda:0 f32[1, 1, 4096]\"\n",
- " # t407 = prims.broadcast_in_dim(t406, [1, 1, 4096], [0, 2]) # t407: \"cuda:0 f32[1, 1, 4096]\"\n",
- " del t406\n",
- " t118 = Tensor.expand(t407, (1, 512, 4096)) # t118: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t118 = ltorch.expand(t407, (1, 512, 4096)) # t118: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t118 = prims.broadcast_in_dim(t407, (1, 512, 4096), (0, 1, 2)) # t118: \"cuda:0 f32[1, 512, 4096]\"\n",
- " del t407\n",
- " t427 = torch.unsqueeze(t22, 0) # t427: \"cuda:0 f32[1, 4096]\"\n",
- " # t427 = ltorch.unsqueeze(t22, 0) # t427: \"cuda:0 f32[1, 4096]\"\n",
- " # t427 = prims.broadcast_in_dim(t22, [1, 4096], [1]) # t427: \"cuda:0 f32[1, 4096]\"\n",
- " t428 = torch.unsqueeze(t427, 1) # t428: \"cuda:0 f32[1, 1, 4096]\"\n",
- " # t428 = ltorch.unsqueeze(t427, 1) # t428: \"cuda:0 f32[1, 1, 4096]\"\n",
- " # t428 = prims.broadcast_in_dim(t427, [1, 1, 4096], [0, 2]) # t428: \"cuda:0 f32[1, 1, 4096]\"\n",
- " del t427\n",
- " t169 = Tensor.expand(t428, (1, 512, 4096)) # t169: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t169 = ltorch.expand(t428, (1, 512, 4096)) # t169: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t169 = prims.broadcast_in_dim(t428, (1, 512, 4096), (0, 1, 2)) # t169: \"cuda:0 f32[1, 512, 4096]\"\n",
- " del t428\n",
- " t430 = torch.unsqueeze(t19, 0) # t430: \"cuda:0 f32[1, 4096]\"\n",
- " # t430 = ltorch.unsqueeze(t19, 0) # t430: \"cuda:0 f32[1, 4096]\"\n",
- " # t430 = prims.broadcast_in_dim(t19, [1, 4096], [1]) # t430: \"cuda:0 f32[1, 4096]\"\n",
- " t431 = torch.unsqueeze(t430, 1) # t431: \"cuda:0 f32[1, 1, 4096]\"\n",
- " # t431 = ltorch.unsqueeze(t430, 1) # t431: \"cuda:0 f32[1, 1, 4096]\"\n",
- " # t431 = prims.broadcast_in_dim(t430, [1, 1, 4096], [0, 2]) # t431: \"cuda:0 f32[1, 1, 4096]\"\n",
- " del t430\n",
- " t189 = Tensor.expand(t431, (1, 512, 4096)) # t189: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t189 = ltorch.expand(t431, (1, 512, 4096)) # t189: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t189 = prims.broadcast_in_dim(t431, (1, 512, 4096), (0, 1, 2)) # t189: \"cuda:0 f32[1, 512, 4096]\"\n",
- " del t431\n",
- " t451 = torch.unsqueeze(t23, 0) # t451: \"cuda:0 f32[1, 4096]\"\n",
- " # t451 = ltorch.unsqueeze(t23, 0) # t451: \"cuda:0 f32[1, 4096]\"\n",
- " # t451 = prims.broadcast_in_dim(t23, [1, 4096], [1]) # t451: \"cuda:0 f32[1, 4096]\"\n",
- " t452 = torch.unsqueeze(t451, 1) # t452: \"cuda:0 f32[1, 1, 4096]\"\n",
- " # t452 = ltorch.unsqueeze(t451, 1) # t452: \"cuda:0 f32[1, 1, 4096]\"\n",
- " # t452 = prims.broadcast_in_dim(t451, [1, 1, 4096], [0, 2]) # t452: \"cuda:0 f32[1, 1, 4096]\"\n",
- " del t451\n",
- " t240 = Tensor.expand(t452, (1, 512, 4096)) # t240: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t240 = ltorch.expand(t452, (1, 512, 4096)) # t240: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t240 = prims.broadcast_in_dim(t452, (1, 512, 4096), (0, 1, 2)) # t240: \"cuda:0 f32[1, 512, 4096]\"\n",
- " del t452\n",
- " t454 = torch.unsqueeze(t20, 0) # t454: \"cuda:0 f32[1, 4096]\"\n",
- " # t454 = ltorch.unsqueeze(t20, 0) # t454: \"cuda:0 f32[1, 4096]\"\n",
- " # t454 = prims.broadcast_in_dim(t20, [1, 4096], [1]) # t454: \"cuda:0 f32[1, 4096]\"\n",
- " t455 = torch.unsqueeze(t454, 1) # t455: \"cuda:0 f32[1, 1, 4096]\"\n",
- " # t455 = ltorch.unsqueeze(t454, 1) # t455: \"cuda:0 f32[1, 1, 4096]\"\n",
- " # t455 = prims.broadcast_in_dim(t454, [1, 1, 4096], [0, 2]) # t455: \"cuda:0 f32[1, 1, 4096]\"\n",
- " del t454\n",
- " t260 = Tensor.expand(t455, (1, 512, 4096)) # t260: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t260 = ltorch.expand(t455, (1, 512, 4096)) # t260: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t260 = prims.broadcast_in_dim(t455, (1, 512, 4096), (0, 1, 2)) # t260: \"cuda:0 f32[1, 512, 4096]\"\n",
- " del t455\n",
- " t395 = torch.unsqueeze(t34, 0) # t395: \"cuda:0 f32[1, 512, 128]\"\n",
- " # t395 = ltorch.unsqueeze(t34, 0) # t395: \"cuda:0 f32[1, 512, 128]\"\n",
- " # t395 = prims.broadcast_in_dim(t34, [1, 512, 128], [1, 2]) # t395: \"cuda:0 f32[1, 512, 128]\"\n",
- " del t34\n",
- " t396 = torch.unsqueeze(t395, 1) # t396: \"cuda:0 f32[1, 1, 512, 128]\"\n",
- " # t396 = ltorch.unsqueeze(t395, 1) # t396: \"cuda:0 f32[1, 1, 512, 128]\"\n",
- " # t396 = prims.broadcast_in_dim(t395, [1, 1, 512, 128], [0, 2, 3]) # t396: \"cuda:0 f32[1, 1, 512, 128]\"\n",
- " del t395\n",
- " t63 = Tensor.expand(t396, (1, 32, 512, 128)) # t63: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t63 = ltorch.expand(t396, (1, 32, 512, 128)) # t63: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t63 = prims.broadcast_in_dim(t396, (1, 32, 512, 128), (0, 1, 2, 3)) # t63: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t396\n",
- " t398 = torch.unsqueeze(t35, 0) # t398: \"cuda:0 f32[1, 512, 128]\"\n",
- " # t398 = ltorch.unsqueeze(t35, 0) # t398: \"cuda:0 f32[1, 512, 128]\"\n",
- " # t398 = prims.broadcast_in_dim(t35, [1, 512, 128], [1, 2]) # t398: \"cuda:0 f32[1, 512, 128]\"\n",
- " del t35\n",
- " t399 = torch.unsqueeze(t398, 1) # t399: \"cuda:0 f32[1, 1, 512, 128]\"\n",
- " # t399 = ltorch.unsqueeze(t398, 1) # t399: \"cuda:0 f32[1, 1, 512, 128]\"\n",
- " # t399 = prims.broadcast_in_dim(t398, [1, 1, 512, 128], [0, 2, 3]) # t399: \"cuda:0 f32[1, 1, 512, 128]\"\n",
- " del t398\n",
- " t65 = Tensor.expand(t399, (1, 32, 512, 128)) # t65: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t65 = ltorch.expand(t399, (1, 32, 512, 128)) # t65: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t65 = prims.broadcast_in_dim(t399, (1, 32, 512, 128), (0, 1, 2, 3)) # t65: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t399\n",
- " [t44, t48] = nvFusion0(t38, t47)\n",
- " # t39 = prims.mul(t38, t38) # t39: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t40 = prims.sum(t39, (2,)) # t40: \"cuda:0 f32[1, 512]\"\n",
- " # t41 = prims.broadcast_in_dim(t40, [1, 512, 1], [0, 1]) # t41: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t42 = prims.div(t41, 4096.0) # t42: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t43 = prims.add(t42, 1e-05) # t43: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t44 = prims.rsqrt(t43) # t44: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t45 = prims.broadcast_in_dim(t44, (1, 512, 4096), (0, 1, 2)) # t45: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t46 = prims.mul(t38, t45) # t46: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t48 = prims.mul(t46, t47) # t48: \"cuda:0 f32[1, 512, 4096]\"\n",
- " t49 = torch.nn.functional.linear(t48, t3, None) # t49: \"cuda:0 f32[1, 512, 12288]\"\n",
- " # t49 = ltorch.linear(t48, t3, None) # t49: \"cuda:0 f32[1, 512, 12288]\"\n",
- " # t49 = prims.linear(t48, t3, None) # t49: \"cuda:0 f32[1, 512, 12288]\"\n",
- " t50 = torch.reshape(t49, (1, 512, 32, 3, 128)) # t50: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n",
- " # t50 = ltorch.reshape(t49, (1, 512, 32, 3, 128)) # t50: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n",
- " # t50 = prims.reshape(t49, (1, 512, 32, 3, 128)) # t50: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n",
- " del t49\n",
- " t51 = torch.permute(t50, (0, 2, 3, 1, 4)) # t51: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n",
- " # t51 = ltorch.permute(t50, (0, 2, 3, 1, 4)) # t51: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n",
- " # t51 = prims.transpose(t50, (0, 2, 3, 1, 4)) # t51: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n",
- " del t50\n",
- " (t52, t53, t54) = torch.split(t51, (1, 1, 1), 2)\n",
- " # (t52, t53, t54) = ltorch.split(t51, (1, 1, 1), 2)\n",
- " # t52 = prims.slice_prim(t51, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t52: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n",
- " # t53 = prims.slice_prim(t51, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t53: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n",
- " # t54 = prims.slice_prim(t51, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t54: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n",
- " del t51\n",
- " t55 = torch.reshape(t52, (1, 32, 512, 128)) # t55: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t55 = ltorch.reshape(t52, (1, 32, 512, 128)) # t55: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t55 = prims.reshape(t52, (1, 32, 512, 128)) # t55: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t52\n",
- " t56 = torch.reshape(t53, (1, 32, 512, 128)) # t56: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t56 = ltorch.reshape(t53, (1, 32, 512, 128)) # t56: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t56 = prims.reshape(t53, (1, 32, 512, 128)) # t56: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t53\n",
- " t57 = torch.reshape(t54, (1, 32, 512, 128)) # t57: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t57 = ltorch.reshape(t54, (1, 32, 512, 128)) # t57: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t57 = prims.reshape(t54, (1, 32, 512, 128)) # t57: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t54\n",
- " t58 = torch_slice_prim_impl(t55, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t58: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " t68 = torch_slice_prim_impl(t56, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t68: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " t78 = torch_slice_prim_impl(t55, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t78: \"cuda:0 f32[1, 32, 512, 0]\"\n",
- " del t55\n",
- " t80 = torch_slice_prim_impl(t56, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t80: \"cuda:0 f32[1, 32, 512, 0]\"\n",
- " del t56\n",
- " t60 = torch_slice_prim_impl(t58, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t60: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " t59 = torch_slice_prim_impl(t58, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t59: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " t69 = torch_slice_prim_impl(t68, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t69: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " t70 = torch_slice_prim_impl(t68, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t70: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " [t61, t71] = nvFusion1(t60, t70)\n",
- " # t61 = prims.neg(t60) # t61: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " # t71 = prims.neg(t70) # t71: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " del t60, t70\n",
- " t62 = torch.cat((t61, t59), -1) # t62: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t62 = ltorch.cat((t61, t59), -1) # t62: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t62 = prims.cat((t61, t59), -1) # t62: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t61, t59\n",
- " t72 = torch.cat((t71, t69), -1) # t72: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t72 = ltorch.cat((t71, t69), -1) # t72: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t72 = prims.cat((t71, t69), -1) # t72: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t71, t69\n",
- " [t67, t77] = nvFusion2(t58, t62, t63, t65, t68, t72)\n",
- " # t64 = prims.mul(t58, t63) # t64: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t66 = prims.mul(t62, t65) # t66: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t67 = prims.add(t64, t66) # t67: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t74 = prims.mul(t68, t63) # t74: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t76 = prims.mul(t72, t65) # t76: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t77 = prims.add(t74, t76) # t77: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t58, t62, t68, t72\n",
- " t79 = torch.cat((t67, t78), -1) # t79: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t79 = ltorch.cat((t67, t78), -1) # t79: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t79 = prims.cat((t67, t78), -1) # t79: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t67, t78\n",
- " t81 = torch.cat((t77, t80), -1) # t81: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t81 = ltorch.cat((t77, t80), -1) # t81: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t81 = prims.cat((t77, t80), -1) # t81: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t77, t80\n",
- " (t82, t83, t84, t85) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t79, t81, t57, None, 0.0, True, 0.08838834764831843)\n",
- " t86 = torch.permute(t82, (0, 2, 1, 3)) # t86: \"cuda:0 f32[1, 512, 32, 128]\"\n",
- " # t86 = ltorch.permute(t82, (0, 2, 1, 3)) # t86: \"cuda:0 f32[1, 512, 32, 128]\"\n",
- " # t86 = prims.transpose(t82, (0, 2, 1, 3)) # t86: \"cuda:0 f32[1, 512, 32, 128]\"\n",
- " t87 = torch.reshape(t86, (1, 512, 4096)) # t87: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t87 = ltorch.reshape(t86, (1, 512, 4096)) # t87: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t87 = prims.reshape(t86, (1, 512, 4096)) # t87: \"cuda:0 f32[1, 512, 4096]\"\n",
- " del t86\n",
- " t88 = torch.nn.functional.linear(t87, t25, None) # t88: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t88 = ltorch.linear(t87, t25, None) # t88: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t88 = prims.linear(t87, t25, None) # t88: \"cuda:0 f32[1, 512, 4096]\"\n",
- " [t89, t95, t99] = nvFusion3(t38, t88, t98)\n",
- " # t89 = prims.add(t88, t38) # t89: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t90 = prims.mul(t89, t89) # t90: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t91 = prims.sum(t90, (2,)) # t91: \"cuda:0 f32[1, 512]\"\n",
- " # t92 = prims.broadcast_in_dim(t91, [1, 512, 1], [0, 1]) # t92: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t93 = prims.div(t92, 4096.0) # t93: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t94 = prims.add(t93, 1e-05) # t94: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t95 = prims.rsqrt(t94) # t95: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t96 = prims.broadcast_in_dim(t95, (1, 512, 4096), (0, 1, 2)) # t96: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t97 = prims.mul(t89, t96) # t97: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t99 = prims.mul(t97, t98) # t99: \"cuda:0 f32[1, 512, 4096]\"\n",
- " del t88\n",
- " t101 = torch.nn.functional.linear(t99, t11, None) # t101: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t101 = ltorch.linear(t99, t11, None) # t101: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t101 = prims.linear(t99, t11, None) # t101: \"cuda:0 f32[1, 512, 11008]\"\n",
- " t100 = torch.nn.functional.linear(t99, t7, None) # t100: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t100 = ltorch.linear(t99, t7, None) # t100: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t100 = prims.linear(t99, t7, None) # t100: \"cuda:0 f32[1, 512, 11008]\"\n",
- " [t107] = nvFusion4(t100, t101)\n",
- " # t102 = prims.neg(t100) # t102: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t103 = prims.exp(t102) # t103: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t104 = prims.add(1.0, t103) # t104: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t105 = prims.reciprocal(t104) # t105: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t106 = prims.mul(t100, t105) # t106: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t107 = prims.mul(t106, t101) # t107: \"cuda:0 f32[1, 512, 11008]\"\n",
- " t108 = torch.nn.functional.linear(t107, t26, None) # t108: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t108 = ltorch.linear(t107, t26, None) # t108: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t108 = prims.linear(t107, t26, None) # t108: \"cuda:0 f32[1, 512, 4096]\"\n",
- " [t109, t115, t119] = nvFusion5(t108, t118, t89)\n",
- " # t109 = prims.add(t108, t89) # t109: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t110 = prims.mul(t109, t109) # t110: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t111 = prims.sum(t110, (2,)) # t111: \"cuda:0 f32[1, 512]\"\n",
- " # t112 = prims.broadcast_in_dim(t111, [1, 512, 1], [0, 1]) # t112: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t113 = prims.div(t112, 4096.0) # t113: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t114 = prims.add(t113, 1e-05) # t114: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t115 = prims.rsqrt(t114) # t115: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t116 = prims.broadcast_in_dim(t115, (1, 512, 4096), (0, 1, 2)) # t116: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t117 = prims.mul(t109, t116) # t117: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t119 = prims.mul(t117, t118) # t119: \"cuda:0 f32[1, 512, 4096]\"\n",
- " del t108\n",
- " t120 = torch.nn.functional.linear(t119, t4, None) # t120: \"cuda:0 f32[1, 512, 12288]\"\n",
- " # t120 = ltorch.linear(t119, t4, None) # t120: \"cuda:0 f32[1, 512, 12288]\"\n",
- " # t120 = prims.linear(t119, t4, None) # t120: \"cuda:0 f32[1, 512, 12288]\"\n",
- " t121 = torch.reshape(t120, (1, 512, 32, 3, 128)) # t121: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n",
- " # t121 = ltorch.reshape(t120, (1, 512, 32, 3, 128)) # t121: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n",
- " # t121 = prims.reshape(t120, (1, 512, 32, 3, 128)) # t121: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n",
- " del t120\n",
- " t122 = torch.permute(t121, (0, 2, 3, 1, 4)) # t122: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n",
- " # t122 = ltorch.permute(t121, (0, 2, 3, 1, 4)) # t122: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n",
- " # t122 = prims.transpose(t121, (0, 2, 3, 1, 4)) # t122: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n",
- " del t121\n",
- " (t123, t124, t125) = torch.split(t122, (1, 1, 1), 2)\n",
- " # (t123, t124, t125) = ltorch.split(t122, (1, 1, 1), 2)\n",
- " # t123 = prims.slice_prim(t122, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t123: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n",
- " # t124 = prims.slice_prim(t122, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t124: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n",
- " # t125 = prims.slice_prim(t122, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t125: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n",
- " del t122\n",
- " t126 = torch.reshape(t123, (1, 32, 512, 128)) # t126: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t126 = ltorch.reshape(t123, (1, 32, 512, 128)) # t126: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t126 = prims.reshape(t123, (1, 32, 512, 128)) # t126: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t123\n",
- " t127 = torch.reshape(t124, (1, 32, 512, 128)) # t127: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t127 = ltorch.reshape(t124, (1, 32, 512, 128)) # t127: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t127 = prims.reshape(t124, (1, 32, 512, 128)) # t127: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t124\n",
- " t128 = torch.reshape(t125, (1, 32, 512, 128)) # t128: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t128 = ltorch.reshape(t125, (1, 32, 512, 128)) # t128: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t128 = prims.reshape(t125, (1, 32, 512, 128)) # t128: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t125\n",
- " t149 = torch_slice_prim_impl(t126, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t149: \"cuda:0 f32[1, 32, 512, 0]\"\n",
- " t151 = torch_slice_prim_impl(t127, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t151: \"cuda:0 f32[1, 32, 512, 0]\"\n",
- " t129 = torch_slice_prim_impl(t126, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t129: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t126\n",
- " t139 = torch_slice_prim_impl(t127, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t139: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t127\n",
- " t130 = torch_slice_prim_impl(t129, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t130: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " t131 = torch_slice_prim_impl(t129, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t131: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " t141 = torch_slice_prim_impl(t139, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t141: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " t140 = torch_slice_prim_impl(t139, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t140: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " [t132, t142] = nvFusion6(t131, t141)\n",
- " # t132 = prims.neg(t131) # t132: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " # t142 = prims.neg(t141) # t142: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " del t131, t141\n",
- " t143 = torch.cat((t142, t140), -1) # t143: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t143 = ltorch.cat((t142, t140), -1) # t143: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t143 = prims.cat((t142, t140), -1) # t143: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t142, t140\n",
- " t133 = torch.cat((t132, t130), -1) # t133: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t133 = ltorch.cat((t132, t130), -1) # t133: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t133 = prims.cat((t132, t130), -1) # t133: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t132, t130\n",
- " [t138, t148] = nvFusion7(t129, t133, t139, t143, t63, t65)\n",
- " # t145 = prims.mul(t139, t63) # t145: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t147 = prims.mul(t143, t65) # t147: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t148 = prims.add(t145, t147) # t148: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t135 = prims.mul(t129, t63) # t135: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t137 = prims.mul(t133, t65) # t137: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t138 = prims.add(t135, t137) # t138: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t129, t133, t139, t143\n",
- " t150 = torch.cat((t138, t149), -1) # t150: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t150 = ltorch.cat((t138, t149), -1) # t150: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t150 = prims.cat((t138, t149), -1) # t150: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t138, t149\n",
- " t152 = torch.cat((t148, t151), -1) # t152: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t152 = ltorch.cat((t148, t151), -1) # t152: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t152 = prims.cat((t148, t151), -1) # t152: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t148, t151\n",
- " (t153, t154, t155, t156) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t150, t152, t128, None, 0.0, True, 0.08838834764831843)\n",
- " t157 = torch.permute(t153, (0, 2, 1, 3)) # t157: \"cuda:0 f32[1, 512, 32, 128]\"\n",
- " # t157 = ltorch.permute(t153, (0, 2, 1, 3)) # t157: \"cuda:0 f32[1, 512, 32, 128]\"\n",
- " # t157 = prims.transpose(t153, (0, 2, 1, 3)) # t157: \"cuda:0 f32[1, 512, 32, 128]\"\n",
- " t158 = torch.reshape(t157, (1, 512, 4096)) # t158: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t158 = ltorch.reshape(t157, (1, 512, 4096)) # t158: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t158 = prims.reshape(t157, (1, 512, 4096)) # t158: \"cuda:0 f32[1, 512, 4096]\"\n",
- " del t157\n",
- " t159 = torch.nn.functional.linear(t158, t27, None) # t159: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t159 = ltorch.linear(t158, t27, None) # t159: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t159 = prims.linear(t158, t27, None) # t159: \"cuda:0 f32[1, 512, 4096]\"\n",
- " [t160, t166, t170] = nvFusion8(t109, t159, t169)\n",
- " # t160 = prims.add(t159, t109) # t160: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t161 = prims.mul(t160, t160) # t161: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t162 = prims.sum(t161, (2,)) # t162: \"cuda:0 f32[1, 512]\"\n",
- " # t163 = prims.broadcast_in_dim(t162, [1, 512, 1], [0, 1]) # t163: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t164 = prims.div(t163, 4096.0) # t164: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t165 = prims.add(t164, 1e-05) # t165: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t166 = prims.rsqrt(t165) # t166: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t167 = prims.broadcast_in_dim(t166, (1, 512, 4096), (0, 1, 2)) # t167: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t168 = prims.mul(t160, t167) # t168: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t170 = prims.mul(t168, t169) # t170: \"cuda:0 f32[1, 512, 4096]\"\n",
- " del t159\n",
- " t172 = torch.nn.functional.linear(t170, t12, None) # t172: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t172 = ltorch.linear(t170, t12, None) # t172: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t172 = prims.linear(t170, t12, None) # t172: \"cuda:0 f32[1, 512, 11008]\"\n",
- " t171 = torch.nn.functional.linear(t170, t8, None) # t171: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t171 = ltorch.linear(t170, t8, None) # t171: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t171 = prims.linear(t170, t8, None) # t171: \"cuda:0 f32[1, 512, 11008]\"\n",
- " [t178] = nvFusion9(t171, t172)\n",
- " # t173 = prims.neg(t171) # t173: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t174 = prims.exp(t173) # t174: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t175 = prims.add(1.0, t174) # t175: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t176 = prims.reciprocal(t175) # t176: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t177 = prims.mul(t171, t176) # t177: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t178 = prims.mul(t177, t172) # t178: \"cuda:0 f32[1, 512, 11008]\"\n",
- " t179 = torch.nn.functional.linear(t178, t28, None) # t179: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t179 = ltorch.linear(t178, t28, None) # t179: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t179 = prims.linear(t178, t28, None) # t179: \"cuda:0 f32[1, 512, 4096]\"\n",
- " [t180, t186, t190] = nvFusion10(t160, t179, t189)\n",
- " # t180 = prims.add(t179, t160) # t180: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t181 = prims.mul(t180, t180) # t181: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t182 = prims.sum(t181, (2,)) # t182: \"cuda:0 f32[1, 512]\"\n",
- " # t183 = prims.broadcast_in_dim(t182, [1, 512, 1], [0, 1]) # t183: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t184 = prims.div(t183, 4096.0) # t184: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t185 = prims.add(t184, 1e-05) # t185: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t186 = prims.rsqrt(t185) # t186: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t187 = prims.broadcast_in_dim(t186, (1, 512, 4096), (0, 1, 2)) # t187: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t188 = prims.mul(t180, t187) # t188: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t190 = prims.mul(t188, t189) # t190: \"cuda:0 f32[1, 512, 4096]\"\n",
- " del t179\n",
- " t191 = torch.nn.functional.linear(t190, t5, None) # t191: \"cuda:0 f32[1, 512, 12288]\"\n",
- " # t191 = ltorch.linear(t190, t5, None) # t191: \"cuda:0 f32[1, 512, 12288]\"\n",
- " # t191 = prims.linear(t190, t5, None) # t191: \"cuda:0 f32[1, 512, 12288]\"\n",
- " t192 = torch.reshape(t191, (1, 512, 32, 3, 128)) # t192: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n",
- " # t192 = ltorch.reshape(t191, (1, 512, 32, 3, 128)) # t192: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n",
- " # t192 = prims.reshape(t191, (1, 512, 32, 3, 128)) # t192: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n",
- " del t191\n",
- " t193 = torch.permute(t192, (0, 2, 3, 1, 4)) # t193: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n",
- " # t193 = ltorch.permute(t192, (0, 2, 3, 1, 4)) # t193: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n",
- " # t193 = prims.transpose(t192, (0, 2, 3, 1, 4)) # t193: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n",
- " del t192\n",
- " (t194, t195, t196) = torch.split(t193, (1, 1, 1), 2)\n",
- " # (t194, t195, t196) = ltorch.split(t193, (1, 1, 1), 2)\n",
- " # t194 = prims.slice_prim(t193, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t194: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n",
- " # t195 = prims.slice_prim(t193, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t195: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n",
- " # t196 = prims.slice_prim(t193, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t196: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n",
- " del t193\n",
- " t197 = torch.reshape(t194, (1, 32, 512, 128)) # t197: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t197 = ltorch.reshape(t194, (1, 32, 512, 128)) # t197: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t197 = prims.reshape(t194, (1, 32, 512, 128)) # t197: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t194\n",
- " t198 = torch.reshape(t195, (1, 32, 512, 128)) # t198: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t198 = ltorch.reshape(t195, (1, 32, 512, 128)) # t198: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t198 = prims.reshape(t195, (1, 32, 512, 128)) # t198: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t195\n",
- " t199 = torch.reshape(t196, (1, 32, 512, 128)) # t199: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t199 = ltorch.reshape(t196, (1, 32, 512, 128)) # t199: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t199 = prims.reshape(t196, (1, 32, 512, 128)) # t199: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t196\n",
- " t200 = torch_slice_prim_impl(t197, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t200: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " t210 = torch_slice_prim_impl(t198, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t210: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " t220 = torch_slice_prim_impl(t197, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t220: \"cuda:0 f32[1, 32, 512, 0]\"\n",
- " del t197\n",
- " t222 = torch_slice_prim_impl(t198, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t222: \"cuda:0 f32[1, 32, 512, 0]\"\n",
- " del t198\n",
- " t201 = torch_slice_prim_impl(t200, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t201: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " t202 = torch_slice_prim_impl(t200, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t202: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " t211 = torch_slice_prim_impl(t210, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t211: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " t212 = torch_slice_prim_impl(t210, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t212: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " [t203, t213] = nvFusion11(t202, t212)\n",
- " # t203 = prims.neg(t202) # t203: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " # t213 = prims.neg(t212) # t213: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " del t202, t212\n",
- " t214 = torch.cat((t213, t211), -1) # t214: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t214 = ltorch.cat((t213, t211), -1) # t214: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t214 = prims.cat((t213, t211), -1) # t214: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t213, t211\n",
- " t204 = torch.cat((t203, t201), -1) # t204: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t204 = ltorch.cat((t203, t201), -1) # t204: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t204 = prims.cat((t203, t201), -1) # t204: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t203, t201\n",
- " [t209, t219] = nvFusion12(t200, t204, t210, t214, t63, t65)\n",
- " # t216 = prims.mul(t210, t63) # t216: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t218 = prims.mul(t214, t65) # t218: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t219 = prims.add(t216, t218) # t219: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t206 = prims.mul(t200, t63) # t206: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t208 = prims.mul(t204, t65) # t208: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t209 = prims.add(t206, t208) # t209: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t200, t204, t210, t214\n",
- " t223 = torch.cat((t219, t222), -1) # t223: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t223 = ltorch.cat((t219, t222), -1) # t223: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t223 = prims.cat((t219, t222), -1) # t223: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t219, t222\n",
- " t221 = torch.cat((t209, t220), -1) # t221: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t221 = ltorch.cat((t209, t220), -1) # t221: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t221 = prims.cat((t209, t220), -1) # t221: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t209, t220\n",
- " (t224, t225, t226, t227) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t221, t223, t199, None, 0.0, True, 0.08838834764831843)\n",
- " t228 = torch.permute(t224, (0, 2, 1, 3)) # t228: \"cuda:0 f32[1, 512, 32, 128]\"\n",
- " # t228 = ltorch.permute(t224, (0, 2, 1, 3)) # t228: \"cuda:0 f32[1, 512, 32, 128]\"\n",
- " # t228 = prims.transpose(t224, (0, 2, 1, 3)) # t228: \"cuda:0 f32[1, 512, 32, 128]\"\n",
- " t229 = torch.reshape(t228, (1, 512, 4096)) # t229: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t229 = ltorch.reshape(t228, (1, 512, 4096)) # t229: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t229 = prims.reshape(t228, (1, 512, 4096)) # t229: \"cuda:0 f32[1, 512, 4096]\"\n",
- " del t228\n",
- " t230 = torch.nn.functional.linear(t229, t29, None) # t230: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t230 = ltorch.linear(t229, t29, None) # t230: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t230 = prims.linear(t229, t29, None) # t230: \"cuda:0 f32[1, 512, 4096]\"\n",
- " [t231, t237, t241] = nvFusion13(t180, t230, t240)\n",
- " # t231 = prims.add(t230, t180) # t231: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t232 = prims.mul(t231, t231) # t232: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " t122 = torch.nn.functional.embedding(t0, t117, None, None, 2.0, False, False) # t122: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t122 = ltorch.embedding(t0, t117, None, None, 2.0, False, False) # t122: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1867 = ltorch.reshape(t0, [512]) # t1867: \"cuda:0 i64[512]\"\n",
+ " # t1867 = prims.reshape(t0, (512,)) # t1867: \"cuda:0 i64[512]\"\n",
+ " # t1868 = prims.take(t117, t1867, 0) # t1868: \"cuda:0 bf16[512, 4096]\"\n",
+ " # t122 = ltorch.reshape(t1868, [1, 512, 4096]) # t122: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t122 = prims.reshape(t1868, (1, 512, 4096)) # t122: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t118 = torch_slice_prim_impl(t1, [0, 0], [512, 128], [1, 1]) # t118: \"cuda:0 f32[512, 128]\"\n",
+ " t119 = torch_slice_prim_impl(t2, [0, 0], [512, 128], [1, 1]) # t119: \"cuda:0 f32[512, 128]\"\n",
+ " t2015 = torch.unsqueeze(t53, 0) # t2015: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2015 = ltorch.unsqueeze(t53, 0) # t2015: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2015 = prims.broadcast_in_dim(t53, [1, 4096], [1]) # t2015: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2016 = torch.unsqueeze(t2015, 1) # t2016: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2016 = ltorch.unsqueeze(t2015, 1) # t2016: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2016 = prims.broadcast_in_dim(t2015, [1, 1, 4096], [0, 2]) # t2016: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2015\n",
+ " t133 = Tensor.expand(t2016, (1, 512, 4096)) # t133: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t133 = ltorch.expand(t2016, (1, 512, 4096)) # t133: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t133 = prims.broadcast_in_dim(t2016, (1, 512, 4096), (0, 1, 2)) # t133: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2016\n",
+ " t2356 = torch.unsqueeze(t82, 0) # t2356: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2356 = ltorch.unsqueeze(t82, 0) # t2356: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2356 = prims.broadcast_in_dim(t82, [1, 4096], [1]) # t2356: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2357 = torch.unsqueeze(t2356, 1) # t2357: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2357 = ltorch.unsqueeze(t2356, 1) # t2357: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2357 = prims.broadcast_in_dim(t2356, [1, 1, 4096], [0, 2]) # t2357: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2356\n",
+ " t1609 = Tensor.expand(t2357, (1, 512, 4096)) # t1609: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1609 = ltorch.expand(t2357, (1, 512, 4096)) # t1609: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1609 = prims.broadcast_in_dim(t2357, (1, 512, 4096), (0, 1, 2)) # t1609: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2357\n",
+ " t2359 = torch.unsqueeze(t58, 0) # t2359: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2359 = ltorch.unsqueeze(t58, 0) # t2359: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2359 = prims.broadcast_in_dim(t58, [1, 4096], [1]) # t2359: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2360 = torch.unsqueeze(t2359, 1) # t2360: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2360 = ltorch.unsqueeze(t2359, 1) # t2360: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2360 = prims.broadcast_in_dim(t2359, [1, 1, 4096], [0, 2]) # t2360: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2359\n",
+ " t1645 = Tensor.expand(t2360, (1, 512, 4096)) # t1645: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1645 = ltorch.expand(t2360, (1, 512, 4096)) # t1645: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1645 = prims.broadcast_in_dim(t2360, (1, 512, 4096), (0, 1, 2)) # t1645: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2360\n",
+ " t2044 = torch.unsqueeze(t69, 0) # t2044: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2044 = ltorch.unsqueeze(t69, 0) # t2044: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2044 = prims.broadcast_in_dim(t69, [1, 4096], [1]) # t2044: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2045 = torch.unsqueeze(t2044, 1) # t2045: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2045 = ltorch.unsqueeze(t2044, 1) # t2045: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2045 = prims.broadcast_in_dim(t2044, [1, 1, 4096], [0, 2]) # t2045: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2044\n",
+ " t205 = Tensor.expand(t2045, (1, 512, 4096)) # t205: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t205 = ltorch.expand(t2045, (1, 512, 4096)) # t205: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t205 = prims.broadcast_in_dim(t2045, (1, 512, 4096), (0, 1, 2)) # t205: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2045\n",
+ " t2380 = torch.unsqueeze(t83, 0) # t2380: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2380 = ltorch.unsqueeze(t83, 0) # t2380: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2380 = prims.broadcast_in_dim(t83, [1, 4096], [1]) # t2380: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2381 = torch.unsqueeze(t2380, 1) # t2381: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2381 = ltorch.unsqueeze(t2380, 1) # t2381: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2381 = prims.broadcast_in_dim(t2380, [1, 1, 4096], [0, 2]) # t2381: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2380\n",
+ " t1717 = Tensor.expand(t2381, (1, 512, 4096)) # t1717: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1717 = ltorch.expand(t2381, (1, 512, 4096)) # t1717: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1717 = prims.broadcast_in_dim(t2381, (1, 512, 4096), (0, 1, 2)) # t1717: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2381\n",
+ " t2047 = torch.unsqueeze(t60, 0) # t2047: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2047 = ltorch.unsqueeze(t60, 0) # t2047: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2047 = prims.broadcast_in_dim(t60, [1, 4096], [1]) # t2047: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2048 = torch.unsqueeze(t2047, 1) # t2048: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2048 = ltorch.unsqueeze(t2047, 1) # t2048: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2048 = prims.broadcast_in_dim(t2047, [1, 1, 4096], [0, 2]) # t2048: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2047\n",
+ " t241 = Tensor.expand(t2048, (1, 512, 4096)) # t241: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t241 = ltorch.expand(t2048, (1, 512, 4096)) # t241: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t241 = prims.broadcast_in_dim(t2048, (1, 512, 4096), (0, 1, 2)) # t241: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2048\n",
+ " t2383 = torch.unsqueeze(t59, 0) # t2383: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2383 = ltorch.unsqueeze(t59, 0) # t2383: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2383 = prims.broadcast_in_dim(t59, [1, 4096], [1]) # t2383: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2384 = torch.unsqueeze(t2383, 1) # t2384: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2384 = ltorch.unsqueeze(t2383, 1) # t2384: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2384 = prims.broadcast_in_dim(t2383, [1, 1, 4096], [0, 2]) # t2384: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2383\n",
+ " t1753 = Tensor.expand(t2384, (1, 512, 4096)) # t1753: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1753 = ltorch.expand(t2384, (1, 512, 4096)) # t1753: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1753 = prims.broadcast_in_dim(t2384, (1, 512, 4096), (0, 1, 2)) # t1753: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2384\n",
+ " t2068 = torch.unsqueeze(t70, 0) # t2068: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2068 = ltorch.unsqueeze(t70, 0) # t2068: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2068 = prims.broadcast_in_dim(t70, [1, 4096], [1]) # t2068: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2069 = torch.unsqueeze(t2068, 1) # t2069: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2069 = ltorch.unsqueeze(t2068, 1) # t2069: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2069 = prims.broadcast_in_dim(t2068, [1, 1, 4096], [0, 2]) # t2069: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2068\n",
+ " t313 = Tensor.expand(t2069, (1, 512, 4096)) # t313: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t313 = ltorch.expand(t2069, (1, 512, 4096)) # t313: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t313 = prims.broadcast_in_dim(t2069, (1, 512, 4096), (0, 1, 2)) # t313: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2069\n",
+ " t2404 = torch.unsqueeze(t84, 0) # t2404: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2404 = ltorch.unsqueeze(t84, 0) # t2404: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2404 = prims.broadcast_in_dim(t84, [1, 4096], [1]) # t2404: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2405 = torch.unsqueeze(t2404, 1) # t2405: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2405 = ltorch.unsqueeze(t2404, 1) # t2405: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2405 = prims.broadcast_in_dim(t2404, [1, 1, 4096], [0, 2]) # t2405: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2404\n",
+ " t1825 = Tensor.expand(t2405, (1, 512, 4096)) # t1825: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1825 = ltorch.expand(t2405, (1, 512, 4096)) # t1825: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1825 = prims.broadcast_in_dim(t2405, (1, 512, 4096), (0, 1, 2)) # t1825: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2405\n",
+ " t2071 = torch.unsqueeze(t61, 0) # t2071: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2071 = ltorch.unsqueeze(t61, 0) # t2071: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2071 = prims.broadcast_in_dim(t61, [1, 4096], [1]) # t2071: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2072 = torch.unsqueeze(t2071, 1) # t2072: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2072 = ltorch.unsqueeze(t2071, 1) # t2072: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2072 = prims.broadcast_in_dim(t2071, [1, 1, 4096], [0, 2]) # t2072: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2071\n",
+ " t349 = Tensor.expand(t2072, (1, 512, 4096)) # t349: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t349 = ltorch.expand(t2072, (1, 512, 4096)) # t349: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t349 = prims.broadcast_in_dim(t2072, (1, 512, 4096), (0, 1, 2)) # t349: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2072\n",
+ " t2407 = torch.unsqueeze(t52, 0) # t2407: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2407 = ltorch.unsqueeze(t52, 0) # t2407: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2407 = prims.broadcast_in_dim(t52, [1, 4096], [1]) # t2407: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2408 = torch.unsqueeze(t2407, 1) # t2408: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2408 = ltorch.unsqueeze(t2407, 1) # t2408: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2408 = prims.broadcast_in_dim(t2407, [1, 1, 4096], [0, 2]) # t2408: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2407\n",
+ " t1861 = Tensor.expand(t2408, (1, 512, 4096)) # t1861: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1861 = ltorch.expand(t2408, (1, 512, 4096)) # t1861: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1861 = prims.broadcast_in_dim(t2408, (1, 512, 4096), (0, 1, 2)) # t1861: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2408\n",
+ " t2095 = torch.unsqueeze(t62, 0) # t2095: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2095 = ltorch.unsqueeze(t62, 0) # t2095: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2095 = prims.broadcast_in_dim(t62, [1, 4096], [1]) # t2095: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2096 = torch.unsqueeze(t2095, 1) # t2096: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2096 = ltorch.unsqueeze(t2095, 1) # t2096: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2096 = prims.broadcast_in_dim(t2095, [1, 1, 4096], [0, 2]) # t2096: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2095\n",
+ " t457 = Tensor.expand(t2096, (1, 512, 4096)) # t457: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t457 = ltorch.expand(t2096, (1, 512, 4096)) # t457: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t457 = prims.broadcast_in_dim(t2096, (1, 512, 4096), (0, 1, 2)) # t457: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2096\n",
+ " t2092 = torch.unsqueeze(t71, 0) # t2092: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2092 = ltorch.unsqueeze(t71, 0) # t2092: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2092 = prims.broadcast_in_dim(t71, [1, 4096], [1]) # t2092: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2093 = torch.unsqueeze(t2092, 1) # t2093: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2093 = ltorch.unsqueeze(t2092, 1) # t2093: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2093 = prims.broadcast_in_dim(t2092, [1, 1, 4096], [0, 2]) # t2093: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2092\n",
+ " t421 = Tensor.expand(t2093, (1, 512, 4096)) # t421: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t421 = ltorch.expand(t2093, (1, 512, 4096)) # t421: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t421 = prims.broadcast_in_dim(t2093, (1, 512, 4096), (0, 1, 2)) # t421: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2093\n",
+ " t2116 = torch.unsqueeze(t72, 0) # t2116: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2116 = ltorch.unsqueeze(t72, 0) # t2116: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2116 = prims.broadcast_in_dim(t72, [1, 4096], [1]) # t2116: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2117 = torch.unsqueeze(t2116, 1) # t2117: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2117 = ltorch.unsqueeze(t2116, 1) # t2117: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2117 = prims.broadcast_in_dim(t2116, [1, 1, 4096], [0, 2]) # t2117: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2116\n",
+ " t529 = Tensor.expand(t2117, (1, 512, 4096)) # t529: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t529 = ltorch.expand(t2117, (1, 512, 4096)) # t529: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t529 = prims.broadcast_in_dim(t2117, (1, 512, 4096), (0, 1, 2)) # t529: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2117\n",
+ " t2119 = torch.unsqueeze(t63, 0) # t2119: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2119 = ltorch.unsqueeze(t63, 0) # t2119: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2119 = prims.broadcast_in_dim(t63, [1, 4096], [1]) # t2119: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2120 = torch.unsqueeze(t2119, 1) # t2120: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2120 = ltorch.unsqueeze(t2119, 1) # t2120: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2120 = prims.broadcast_in_dim(t2119, [1, 1, 4096], [0, 2]) # t2120: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2119\n",
+ " t565 = Tensor.expand(t2120, (1, 512, 4096)) # t565: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t565 = ltorch.expand(t2120, (1, 512, 4096)) # t565: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t565 = prims.broadcast_in_dim(t2120, (1, 512, 4096), (0, 1, 2)) # t565: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2120\n",
+ " t2140 = torch.unsqueeze(t73, 0) # t2140: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2140 = ltorch.unsqueeze(t73, 0) # t2140: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2140 = prims.broadcast_in_dim(t73, [1, 4096], [1]) # t2140: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2141 = torch.unsqueeze(t2140, 1) # t2141: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2141 = ltorch.unsqueeze(t2140, 1) # t2141: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2141 = prims.broadcast_in_dim(t2140, [1, 1, 4096], [0, 2]) # t2141: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2140\n",
+ " t637 = Tensor.expand(t2141, (1, 512, 4096)) # t637: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t637 = ltorch.expand(t2141, (1, 512, 4096)) # t637: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t637 = prims.broadcast_in_dim(t2141, (1, 512, 4096), (0, 1, 2)) # t637: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2141\n",
+ " t2143 = torch.unsqueeze(t64, 0) # t2143: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2143 = ltorch.unsqueeze(t64, 0) # t2143: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2143 = prims.broadcast_in_dim(t64, [1, 4096], [1]) # t2143: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2144 = torch.unsqueeze(t2143, 1) # t2144: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2144 = ltorch.unsqueeze(t2143, 1) # t2144: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2144 = prims.broadcast_in_dim(t2143, [1, 1, 4096], [0, 2]) # t2144: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2143\n",
+ " t673 = Tensor.expand(t2144, (1, 512, 4096)) # t673: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t673 = ltorch.expand(t2144, (1, 512, 4096)) # t673: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t673 = prims.broadcast_in_dim(t2144, (1, 512, 4096), (0, 1, 2)) # t673: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2144\n",
+ " t2164 = torch.unsqueeze(t74, 0) # t2164: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2164 = ltorch.unsqueeze(t74, 0) # t2164: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2164 = prims.broadcast_in_dim(t74, [1, 4096], [1]) # t2164: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2165 = torch.unsqueeze(t2164, 1) # t2165: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2165 = ltorch.unsqueeze(t2164, 1) # t2165: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2165 = prims.broadcast_in_dim(t2164, [1, 1, 4096], [0, 2]) # t2165: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2164\n",
+ " t745 = Tensor.expand(t2165, (1, 512, 4096)) # t745: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t745 = ltorch.expand(t2165, (1, 512, 4096)) # t745: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t745 = prims.broadcast_in_dim(t2165, (1, 512, 4096), (0, 1, 2)) # t745: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2165\n",
+ " t2167 = torch.unsqueeze(t65, 0) # t2167: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2167 = ltorch.unsqueeze(t65, 0) # t2167: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2167 = prims.broadcast_in_dim(t65, [1, 4096], [1]) # t2167: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2168 = torch.unsqueeze(t2167, 1) # t2168: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2168 = ltorch.unsqueeze(t2167, 1) # t2168: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2168 = prims.broadcast_in_dim(t2167, [1, 1, 4096], [0, 2]) # t2168: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2167\n",
+ " t781 = Tensor.expand(t2168, (1, 512, 4096)) # t781: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t781 = ltorch.expand(t2168, (1, 512, 4096)) # t781: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t781 = prims.broadcast_in_dim(t2168, (1, 512, 4096), (0, 1, 2)) # t781: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2168\n",
+ " t2188 = torch.unsqueeze(t75, 0) # t2188: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2188 = ltorch.unsqueeze(t75, 0) # t2188: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2188 = prims.broadcast_in_dim(t75, [1, 4096], [1]) # t2188: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2189 = torch.unsqueeze(t2188, 1) # t2189: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2189 = ltorch.unsqueeze(t2188, 1) # t2189: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2189 = prims.broadcast_in_dim(t2188, [1, 1, 4096], [0, 2]) # t2189: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2188\n",
+ " t853 = Tensor.expand(t2189, (1, 512, 4096)) # t853: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t853 = ltorch.expand(t2189, (1, 512, 4096)) # t853: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t853 = prims.broadcast_in_dim(t2189, (1, 512, 4096), (0, 1, 2)) # t853: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2189\n",
+ " t2191 = torch.unsqueeze(t66, 0) # t2191: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2191 = ltorch.unsqueeze(t66, 0) # t2191: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2191 = prims.broadcast_in_dim(t66, [1, 4096], [1]) # t2191: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2192 = torch.unsqueeze(t2191, 1) # t2192: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2192 = ltorch.unsqueeze(t2191, 1) # t2192: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2192 = prims.broadcast_in_dim(t2191, [1, 1, 4096], [0, 2]) # t2192: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2191\n",
+ " t889 = Tensor.expand(t2192, (1, 512, 4096)) # t889: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t889 = ltorch.expand(t2192, (1, 512, 4096)) # t889: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t889 = prims.broadcast_in_dim(t2192, (1, 512, 4096), (0, 1, 2)) # t889: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2192\n",
+ " t2212 = torch.unsqueeze(t76, 0) # t2212: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2212 = ltorch.unsqueeze(t76, 0) # t2212: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2212 = prims.broadcast_in_dim(t76, [1, 4096], [1]) # t2212: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2213 = torch.unsqueeze(t2212, 1) # t2213: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2213 = ltorch.unsqueeze(t2212, 1) # t2213: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2213 = prims.broadcast_in_dim(t2212, [1, 1, 4096], [0, 2]) # t2213: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2212\n",
+ " t961 = Tensor.expand(t2213, (1, 512, 4096)) # t961: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t961 = ltorch.expand(t2213, (1, 512, 4096)) # t961: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t961 = prims.broadcast_in_dim(t2213, (1, 512, 4096), (0, 1, 2)) # t961: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2213\n",
+ " t2215 = torch.unsqueeze(t67, 0) # t2215: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2215 = ltorch.unsqueeze(t67, 0) # t2215: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2215 = prims.broadcast_in_dim(t67, [1, 4096], [1]) # t2215: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2216 = torch.unsqueeze(t2215, 1) # t2216: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2216 = ltorch.unsqueeze(t2215, 1) # t2216: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2216 = prims.broadcast_in_dim(t2215, [1, 1, 4096], [0, 2]) # t2216: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2215\n",
+ " t997 = Tensor.expand(t2216, (1, 512, 4096)) # t997: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t997 = ltorch.expand(t2216, (1, 512, 4096)) # t997: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t997 = prims.broadcast_in_dim(t2216, (1, 512, 4096), (0, 1, 2)) # t997: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2216\n",
+ " t2236 = torch.unsqueeze(t77, 0) # t2236: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2236 = ltorch.unsqueeze(t77, 0) # t2236: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2236 = prims.broadcast_in_dim(t77, [1, 4096], [1]) # t2236: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2237 = torch.unsqueeze(t2236, 1) # t2237: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2237 = ltorch.unsqueeze(t2236, 1) # t2237: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2237 = prims.broadcast_in_dim(t2236, [1, 1, 4096], [0, 2]) # t2237: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2236\n",
+ " t1069 = Tensor.expand(t2237, (1, 512, 4096)) # t1069: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1069 = ltorch.expand(t2237, (1, 512, 4096)) # t1069: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1069 = prims.broadcast_in_dim(t2237, (1, 512, 4096), (0, 1, 2)) # t1069: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2237\n",
+ " t2239 = torch.unsqueeze(t68, 0) # t2239: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2239 = ltorch.unsqueeze(t68, 0) # t2239: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2239 = prims.broadcast_in_dim(t68, [1, 4096], [1]) # t2239: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2240 = torch.unsqueeze(t2239, 1) # t2240: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2240 = ltorch.unsqueeze(t2239, 1) # t2240: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2240 = prims.broadcast_in_dim(t2239, [1, 1, 4096], [0, 2]) # t2240: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2239\n",
+ " t1105 = Tensor.expand(t2240, (1, 512, 4096)) # t1105: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1105 = ltorch.expand(t2240, (1, 512, 4096)) # t1105: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1105 = prims.broadcast_in_dim(t2240, (1, 512, 4096), (0, 1, 2)) # t1105: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2240\n",
+ " t2260 = torch.unsqueeze(t78, 0) # t2260: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2260 = ltorch.unsqueeze(t78, 0) # t2260: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2260 = prims.broadcast_in_dim(t78, [1, 4096], [1]) # t2260: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2261 = torch.unsqueeze(t2260, 1) # t2261: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2261 = ltorch.unsqueeze(t2260, 1) # t2261: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2261 = prims.broadcast_in_dim(t2260, [1, 1, 4096], [0, 2]) # t2261: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2260\n",
+ " t1177 = Tensor.expand(t2261, (1, 512, 4096)) # t1177: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1177 = ltorch.expand(t2261, (1, 512, 4096)) # t1177: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1177 = prims.broadcast_in_dim(t2261, (1, 512, 4096), (0, 1, 2)) # t1177: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2261\n",
+ " t2263 = torch.unsqueeze(t54, 0) # t2263: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2263 = ltorch.unsqueeze(t54, 0) # t2263: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2263 = prims.broadcast_in_dim(t54, [1, 4096], [1]) # t2263: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2264 = torch.unsqueeze(t2263, 1) # t2264: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2264 = ltorch.unsqueeze(t2263, 1) # t2264: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2264 = prims.broadcast_in_dim(t2263, [1, 1, 4096], [0, 2]) # t2264: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2263\n",
+ " t1213 = Tensor.expand(t2264, (1, 512, 4096)) # t1213: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1213 = ltorch.expand(t2264, (1, 512, 4096)) # t1213: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1213 = prims.broadcast_in_dim(t2264, (1, 512, 4096), (0, 1, 2)) # t1213: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2264\n",
+ " t2284 = torch.unsqueeze(t79, 0) # t2284: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2284 = ltorch.unsqueeze(t79, 0) # t2284: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2284 = prims.broadcast_in_dim(t79, [1, 4096], [1]) # t2284: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2285 = torch.unsqueeze(t2284, 1) # t2285: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2285 = ltorch.unsqueeze(t2284, 1) # t2285: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2285 = prims.broadcast_in_dim(t2284, [1, 1, 4096], [0, 2]) # t2285: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2284\n",
+ " t1285 = Tensor.expand(t2285, (1, 512, 4096)) # t1285: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1285 = ltorch.expand(t2285, (1, 512, 4096)) # t1285: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1285 = prims.broadcast_in_dim(t2285, (1, 512, 4096), (0, 1, 2)) # t1285: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2285\n",
+ " t2287 = torch.unsqueeze(t55, 0) # t2287: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2287 = ltorch.unsqueeze(t55, 0) # t2287: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2287 = prims.broadcast_in_dim(t55, [1, 4096], [1]) # t2287: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2288 = torch.unsqueeze(t2287, 1) # t2288: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2288 = ltorch.unsqueeze(t2287, 1) # t2288: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2288 = prims.broadcast_in_dim(t2287, [1, 1, 4096], [0, 2]) # t2288: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2287\n",
+ " t1321 = Tensor.expand(t2288, (1, 512, 4096)) # t1321: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1321 = ltorch.expand(t2288, (1, 512, 4096)) # t1321: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1321 = prims.broadcast_in_dim(t2288, (1, 512, 4096), (0, 1, 2)) # t1321: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2288\n",
+ " t2308 = torch.unsqueeze(t80, 0) # t2308: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2308 = ltorch.unsqueeze(t80, 0) # t2308: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2308 = prims.broadcast_in_dim(t80, [1, 4096], [1]) # t2308: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2309 = torch.unsqueeze(t2308, 1) # t2309: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2309 = ltorch.unsqueeze(t2308, 1) # t2309: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2309 = prims.broadcast_in_dim(t2308, [1, 1, 4096], [0, 2]) # t2309: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2308\n",
+ " t1393 = Tensor.expand(t2309, (1, 512, 4096)) # t1393: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1393 = ltorch.expand(t2309, (1, 512, 4096)) # t1393: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1393 = prims.broadcast_in_dim(t2309, (1, 512, 4096), (0, 1, 2)) # t1393: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2309\n",
+ " t2311 = torch.unsqueeze(t56, 0) # t2311: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2311 = ltorch.unsqueeze(t56, 0) # t2311: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2311 = prims.broadcast_in_dim(t56, [1, 4096], [1]) # t2311: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2312 = torch.unsqueeze(t2311, 1) # t2312: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2312 = ltorch.unsqueeze(t2311, 1) # t2312: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2312 = prims.broadcast_in_dim(t2311, [1, 1, 4096], [0, 2]) # t2312: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2311\n",
+ " t1429 = Tensor.expand(t2312, (1, 512, 4096)) # t1429: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1429 = ltorch.expand(t2312, (1, 512, 4096)) # t1429: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1429 = prims.broadcast_in_dim(t2312, (1, 512, 4096), (0, 1, 2)) # t1429: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2312\n",
+ " t2332 = torch.unsqueeze(t81, 0) # t2332: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2332 = ltorch.unsqueeze(t81, 0) # t2332: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2332 = prims.broadcast_in_dim(t81, [1, 4096], [1]) # t2332: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2333 = torch.unsqueeze(t2332, 1) # t2333: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2333 = ltorch.unsqueeze(t2332, 1) # t2333: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2333 = prims.broadcast_in_dim(t2332, [1, 1, 4096], [0, 2]) # t2333: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2332\n",
+ " t1501 = Tensor.expand(t2333, (1, 512, 4096)) # t1501: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1501 = ltorch.expand(t2333, (1, 512, 4096)) # t1501: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1501 = prims.broadcast_in_dim(t2333, (1, 512, 4096), (0, 1, 2)) # t1501: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2333\n",
+ " t2335 = torch.unsqueeze(t57, 0) # t2335: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2335 = ltorch.unsqueeze(t57, 0) # t2335: \"cuda:0 bf16[1, 4096]\"\n",
+ " # t2335 = prims.broadcast_in_dim(t57, [1, 4096], [1]) # t2335: \"cuda:0 bf16[1, 4096]\"\n",
+ " t2336 = torch.unsqueeze(t2335, 1) # t2336: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2336 = ltorch.unsqueeze(t2335, 1) # t2336: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " # t2336 = prims.broadcast_in_dim(t2335, [1, 1, 4096], [0, 2]) # t2336: \"cuda:0 bf16[1, 1, 4096]\"\n",
+ " del t2335\n",
+ " t1537 = Tensor.expand(t2336, (1, 512, 4096)) # t1537: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1537 = ltorch.expand(t2336, (1, 512, 4096)) # t1537: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1537 = prims.broadcast_in_dim(t2336, (1, 512, 4096), (0, 1, 2)) # t1537: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t2336\n",
+ " t2036 = torch.unsqueeze(t118, 0) # t2036: \"cuda:0 f32[1, 512, 128]\"\n",
+ " # t2036 = ltorch.unsqueeze(t118, 0) # t2036: \"cuda:0 f32[1, 512, 128]\"\n",
+ " # t2036 = prims.broadcast_in_dim(t118, [1, 512, 128], [1, 2]) # t2036: \"cuda:0 f32[1, 512, 128]\"\n",
+ " del t118\n",
+ " t2037 = torch.unsqueeze(t2036, 1) # t2037: \"cuda:0 f32[1, 1, 512, 128]\"\n",
+ " # t2037 = ltorch.unsqueeze(t2036, 1) # t2037: \"cuda:0 f32[1, 1, 512, 128]\"\n",
+ " # t2037 = prims.broadcast_in_dim(t2036, [1, 1, 512, 128], [0, 2, 3]) # t2037: \"cuda:0 f32[1, 1, 512, 128]\"\n",
+ " del t2036\n",
+ " t154 = Tensor.expand(t2037, (1, 32, 512, 128)) # t154: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t154 = ltorch.expand(t2037, (1, 32, 512, 128)) # t154: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t154 = prims.broadcast_in_dim(t2037, (1, 32, 512, 128), (0, 1, 2, 3)) # t154: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " del t2037\n",
+ " t2039 = torch.unsqueeze(t119, 0) # t2039: \"cuda:0 f32[1, 512, 128]\"\n",
+ " # t2039 = ltorch.unsqueeze(t119, 0) # t2039: \"cuda:0 f32[1, 512, 128]\"\n",
+ " # t2039 = prims.broadcast_in_dim(t119, [1, 512, 128], [1, 2]) # t2039: \"cuda:0 f32[1, 512, 128]\"\n",
+ " del t119\n",
+ " t2040 = torch.unsqueeze(t2039, 1) # t2040: \"cuda:0 f32[1, 1, 512, 128]\"\n",
+ " # t2040 = ltorch.unsqueeze(t2039, 1) # t2040: \"cuda:0 f32[1, 1, 512, 128]\"\n",
+ " # t2040 = prims.broadcast_in_dim(t2039, [1, 1, 512, 128], [0, 2, 3]) # t2040: \"cuda:0 f32[1, 1, 512, 128]\"\n",
+ " del t2039\n",
+ " t157 = Tensor.expand(t2040, (1, 32, 512, 128)) # t157: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t157 = ltorch.expand(t2040, (1, 32, 512, 128)) # t157: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t157 = prims.broadcast_in_dim(t2040, (1, 32, 512, 128), (0, 1, 2, 3)) # t157: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " del t2040\n",
+ " [t129, t137] = nvFusion0(t122, t133)\n",
+ " # t123 = prims.convert_element_type(t122, dtypes.float32) # t123: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t124 = prims.mul(t123, t123) # t124: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t125 = prims.sum(t124, (2,)) # t125: \"cuda:0 f32[1, 512]\"\n",
+ " # t126 = prims.broadcast_in_dim(t125, [1, 512, 1], [0, 1]) # t126: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t127 = prims.div(t126, 4096.0) # t127: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t128 = prims.add(t127, 1e-05) # t128: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t129 = prims.rsqrt(t128) # t129: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t130 = prims.broadcast_in_dim(t129, (1, 512, 4096), (0, 1, 2)) # t130: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t131 = prims.mul(t123, t130) # t131: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t135 = prims.convert_element_type(t133, dtypes.float32) # t135: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t136 = prims.mul(t131, t135) # t136: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t137 = prims.convert_element_type(t136, dtypes.bfloat16) # t137: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t138 = torch.nn.functional.linear(t137, t3, None) # t138: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t138 = ltorch.linear(t137, t3, None) # t138: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t138 = prims.linear(t137, t3, None) # t138: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " t139 = torch.reshape(t138, (1, 512, 32, 3, 128)) # t139: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t139 = ltorch.reshape(t138, (1, 512, 32, 3, 128)) # t139: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t139 = prims.reshape(t138, (1, 512, 32, 3, 128)) # t139: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " del t138\n",
+ " t140 = torch.permute(t139, (0, 2, 3, 1, 4)) # t140: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t140 = ltorch.permute(t139, (0, 2, 3, 1, 4)) # t140: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t140 = prims.transpose(t139, (0, 2, 3, 1, 4)) # t140: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " del t139\n",
+ " (t141, t142, t143) = torch.split(t140, (1, 1, 1), 2)\n",
+ " # (t141, t142, t143) = ltorch.split(t140, (1, 1, 1), 2)\n",
+ " # t141 = prims.slice_prim(t140, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t141: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t142 = prims.slice_prim(t140, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t142: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t143 = prims.slice_prim(t140, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t143: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " del t140\n",
+ " t144 = torch.reshape(t141, (1, 32, 512, 128)) # t144: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t144 = ltorch.reshape(t141, (1, 32, 512, 128)) # t144: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t144 = prims.reshape(t141, (1, 32, 512, 128)) # t144: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t141\n",
+ " t145 = torch.reshape(t142, (1, 32, 512, 128)) # t145: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t145 = ltorch.reshape(t142, (1, 32, 512, 128)) # t145: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t145 = prims.reshape(t142, (1, 32, 512, 128)) # t145: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t142\n",
+ " t146 = torch.reshape(t143, (1, 32, 512, 128)) # t146: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t146 = ltorch.reshape(t143, (1, 32, 512, 128)) # t146: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t146 = prims.reshape(t143, (1, 32, 512, 128)) # t146: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t143\n",
+ " t147 = torch_slice_prim_impl(t144, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t147: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t162 = torch_slice_prim_impl(t145, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t162: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t177 = torch_slice_prim_impl(t144, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t177: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t144\n",
+ " t179 = torch_slice_prim_impl(t145, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t179: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t145\n",
+ " t149 = torch_slice_prim_impl(t147, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t149: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t148 = torch_slice_prim_impl(t147, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t148: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t163 = torch_slice_prim_impl(t162, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t163: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t164 = torch_slice_prim_impl(t162, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t164: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " [t152, t167] = nvFusion1(t147, t149, t162, t164)\n",
+ " # t150 = prims.convert_element_type(t149, dtypes.float32) # t150: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t151 = prims.neg(t150) # t151: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t152 = prims.convert_element_type(t151, dtypes.bfloat16) # t152: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " # t165 = prims.convert_element_type(t164, dtypes.float32) # t165: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t166 = prims.neg(t165) # t166: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t167 = prims.convert_element_type(t166, dtypes.bfloat16) # t167: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " del t149, t164\n",
+ " t168 = torch.cat((t167, t163), -1) # t168: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t168 = ltorch.cat((t167, t163), -1) # t168: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t168 = prims.cat((t167, t163), -1) # t168: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t167, t163\n",
+ " t153 = torch.cat((t152, t148), -1) # t153: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t153 = ltorch.cat((t152, t148), -1) # t153: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t153 = prims.cat((t152, t148), -1) # t153: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t152, t148\n",
+ " [t161, t176] = nvFusion2(t147, t153, t154, t157, t162, t168)\n",
+ " # t155 = prims.convert_element_type(t147, dtypes.float32) # t155: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t170 = prims.convert_element_type(t162, dtypes.float32) # t170: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t156 = prims.mul(t155, t154) # t156: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t158 = prims.convert_element_type(t153, dtypes.float32) # t158: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t159 = prims.mul(t158, t157) # t159: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t160 = prims.add(t156, t159) # t160: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t161 = prims.convert_element_type(t160, dtypes.bfloat16) # t161: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t171 = prims.mul(t170, t154) # t171: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t173 = prims.convert_element_type(t168, dtypes.float32) # t173: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t174 = prims.mul(t173, t157) # t174: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t175 = prims.add(t171, t174) # t175: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t176 = prims.convert_element_type(t175, dtypes.bfloat16) # t176: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t147, t153, t162, t168\n",
+ " t178 = torch.cat((t161, t177), -1) # t178: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t178 = ltorch.cat((t161, t177), -1) # t178: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t178 = prims.cat((t161, t177), -1) # t178: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t161, t177\n",
+ " t180 = torch.cat((t176, t179), -1) # t180: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t180 = ltorch.cat((t176, t179), -1) # t180: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t180 = prims.cat((t176, t179), -1) # t180: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t176, t179\n",
+ " (t181, t182, t183, t184, _, _, t185, t186, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t178, t180, t146, 0.0, True, scale=0.08838834764831843)\n",
+ " t188 = torch.permute(t181, (0, 2, 1, 3)) # t188: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t188 = ltorch.permute(t181, (0, 2, 1, 3)) # t188: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t188 = prims.transpose(t181, (0, 2, 1, 3)) # t188: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " t189 = torch.reshape(t188, (1, 512, 4096)) # t189: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t189 = ltorch.reshape(t188, (1, 512, 4096)) # t189: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t189 = prims.reshape(t188, (1, 512, 4096)) # t189: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t188\n",
+ " t190 = torch.nn.functional.linear(t189, t85, None) # t190: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t190 = ltorch.linear(t189, t85, None) # t190: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t190 = prims.linear(t189, t85, None) # t190: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t194, t201, t209] = nvFusion3(t122, t190, t205)\n",
+ " # t191 = prims.convert_element_type(t190, dtypes.float32) # t191: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t192 = prims.convert_element_type(t122, dtypes.float32) # t192: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t193 = prims.add(t191, t192) # t193: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t194 = prims.convert_element_type(t193, dtypes.bfloat16) # t194: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t196 = prims.mul(t193, t193) # t196: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t197 = prims.sum(t196, (2,)) # t197: \"cuda:0 f32[1, 512]\"\n",
+ " # t198 = prims.broadcast_in_dim(t197, [1, 512, 1], [0, 1]) # t198: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t199 = prims.div(t198, 4096.0) # t199: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t200 = prims.add(t199, 1e-05) # t200: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t201 = prims.rsqrt(t200) # t201: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t202 = prims.broadcast_in_dim(t201, (1, 512, 4096), (0, 1, 2)) # t202: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t203 = prims.mul(t193, t202) # t203: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t207 = prims.convert_element_type(t205, dtypes.float32) # t207: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t208 = prims.mul(t203, t207) # t208: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t209 = prims.convert_element_type(t208, dtypes.bfloat16) # t209: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t210 = torch.nn.functional.linear(t209, t19, None) # t210: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t210 = ltorch.linear(t209, t19, None) # t210: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t210 = prims.linear(t209, t19, None) # t210: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t211 = torch.nn.functional.linear(t209, t35, None) # t211: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t211 = ltorch.linear(t209, t35, None) # t211: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t211 = prims.linear(t209, t35, None) # t211: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " [t225] = nvFusion4(t210, t211)\n",
+ " # t212 = prims.convert_element_type(t210, dtypes.float32) # t212: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t213 = prims.neg(t212) # t213: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t214 = prims.exp(t213) # t214: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t215 = prims.add(1.0, t214) # t215: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t216 = prims.reciprocal(t215) # t216: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t220 = prims.mul(t212, t216) # t220: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t223 = prims.convert_element_type(t211, dtypes.float32) # t223: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t224 = prims.mul(t220, t223) # t224: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t225 = prims.convert_element_type(t224, dtypes.bfloat16) # t225: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t226 = torch.nn.functional.linear(t225, t86, None) # t226: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t226 = ltorch.linear(t225, t86, None) # t226: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t226 = prims.linear(t225, t86, None) # t226: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t230, t237, t245] = nvFusion5(t194, t226, t241)\n",
+ " # t228 = prims.convert_element_type(t194, dtypes.float32) # t228: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t227 = prims.convert_element_type(t226, dtypes.float32) # t227: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t229 = prims.add(t227, t228) # t229: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t230 = prims.convert_element_type(t229, dtypes.bfloat16) # t230: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t232 = prims.mul(t229, t229) # t232: \"cuda:0 f32[1, 512, 4096]\"\n",
" # t233 = prims.sum(t232, (2,)) # t233: \"cuda:0 f32[1, 512]\"\n",
" # t234 = prims.broadcast_in_dim(t233, [1, 512, 1], [0, 1]) # t234: \"cuda:0 f32[1, 512, 1]\"\n",
" # t235 = prims.div(t234, 4096.0) # t235: \"cuda:0 f32[1, 512, 1]\"\n",
" # t236 = prims.add(t235, 1e-05) # t236: \"cuda:0 f32[1, 512, 1]\"\n",
" # t237 = prims.rsqrt(t236) # t237: \"cuda:0 f32[1, 512, 1]\"\n",
" # t238 = prims.broadcast_in_dim(t237, (1, 512, 4096), (0, 1, 2)) # t238: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t239 = prims.mul(t231, t238) # t239: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t241 = prims.mul(t239, t240) # t241: \"cuda:0 f32[1, 512, 4096]\"\n",
- " del t230\n",
- " t242 = torch.nn.functional.linear(t241, t9, None) # t242: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t242 = ltorch.linear(t241, t9, None) # t242: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t242 = prims.linear(t241, t9, None) # t242: \"cuda:0 f32[1, 512, 11008]\"\n",
- " t243 = torch.nn.functional.linear(t241, t13, None) # t243: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t243 = ltorch.linear(t241, t13, None) # t243: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t243 = prims.linear(t241, t13, None) # t243: \"cuda:0 f32[1, 512, 11008]\"\n",
- " [t249] = nvFusion14(t242, t243)\n",
- " # t244 = prims.neg(t242) # t244: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t245 = prims.exp(t244) # t245: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t246 = prims.add(1.0, t245) # t246: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t247 = prims.reciprocal(t246) # t247: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t248 = prims.mul(t242, t247) # t248: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t249 = prims.mul(t248, t243) # t249: \"cuda:0 f32[1, 512, 11008]\"\n",
- " t250 = torch.nn.functional.linear(t249, t30, None) # t250: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t250 = ltorch.linear(t249, t30, None) # t250: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t250 = prims.linear(t249, t30, None) # t250: \"cuda:0 f32[1, 512, 4096]\"\n",
- " [t251, t257, t261] = nvFusion15(t231, t250, t260)\n",
- " # t251 = prims.add(t250, t231) # t251: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t252 = prims.mul(t251, t251) # t252: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t253 = prims.sum(t252, (2,)) # t253: \"cuda:0 f32[1, 512]\"\n",
- " # t254 = prims.broadcast_in_dim(t253, [1, 512, 1], [0, 1]) # t254: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t255 = prims.div(t254, 4096.0) # t255: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t256 = prims.add(t255, 1e-05) # t256: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t257 = prims.rsqrt(t256) # t257: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t258 = prims.broadcast_in_dim(t257, (1, 512, 4096), (0, 1, 2)) # t258: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t259 = prims.mul(t251, t258) # t259: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t261 = prims.mul(t259, t260) # t261: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t239 = prims.mul(t229, t238) # t239: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t243 = prims.convert_element_type(t241, dtypes.float32) # t243: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t244 = prims.mul(t239, t243) # t244: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t245 = prims.convert_element_type(t244, dtypes.bfloat16) # t245: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t246 = torch.nn.functional.linear(t245, t4, None) # t246: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t246 = ltorch.linear(t245, t4, None) # t246: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t246 = prims.linear(t245, t4, None) # t246: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " t247 = torch.reshape(t246, (1, 512, 32, 3, 128)) # t247: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t247 = ltorch.reshape(t246, (1, 512, 32, 3, 128)) # t247: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t247 = prims.reshape(t246, (1, 512, 32, 3, 128)) # t247: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " del t246\n",
+ " t248 = torch.permute(t247, (0, 2, 3, 1, 4)) # t248: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t248 = ltorch.permute(t247, (0, 2, 3, 1, 4)) # t248: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t248 = prims.transpose(t247, (0, 2, 3, 1, 4)) # t248: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " del t247\n",
+ " (t249, t250, t251) = torch.split(t248, (1, 1, 1), 2)\n",
+ " # (t249, t250, t251) = ltorch.split(t248, (1, 1, 1), 2)\n",
+ " # t249 = prims.slice_prim(t248, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t249: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t250 = prims.slice_prim(t248, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t250: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t251 = prims.slice_prim(t248, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t251: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " del t248\n",
+ " t252 = torch.reshape(t249, (1, 32, 512, 128)) # t252: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t252 = ltorch.reshape(t249, (1, 32, 512, 128)) # t252: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t252 = prims.reshape(t249, (1, 32, 512, 128)) # t252: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t249\n",
+ " t253 = torch.reshape(t250, (1, 32, 512, 128)) # t253: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t253 = ltorch.reshape(t250, (1, 32, 512, 128)) # t253: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t253 = prims.reshape(t250, (1, 32, 512, 128)) # t253: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
" del t250\n",
- " t262 = torch.nn.functional.linear(t261, t6, None) # t262: \"cuda:0 f32[1, 512, 12288]\"\n",
- " # t262 = ltorch.linear(t261, t6, None) # t262: \"cuda:0 f32[1, 512, 12288]\"\n",
- " # t262 = prims.linear(t261, t6, None) # t262: \"cuda:0 f32[1, 512, 12288]\"\n",
- " t263 = torch.reshape(t262, (1, 512, 32, 3, 128)) # t263: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n",
- " # t263 = ltorch.reshape(t262, (1, 512, 32, 3, 128)) # t263: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n",
- " # t263 = prims.reshape(t262, (1, 512, 32, 3, 128)) # t263: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n",
- " del t262\n",
- " t264 = torch.permute(t263, (0, 2, 3, 1, 4)) # t264: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n",
- " # t264 = ltorch.permute(t263, (0, 2, 3, 1, 4)) # t264: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n",
- " # t264 = prims.transpose(t263, (0, 2, 3, 1, 4)) # t264: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n",
- " del t263\n",
- " (t265, t266, t267) = torch.split(t264, (1, 1, 1), 2)\n",
- " # (t265, t266, t267) = ltorch.split(t264, (1, 1, 1), 2)\n",
- " # t265 = prims.slice_prim(t264, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t265: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n",
- " # t266 = prims.slice_prim(t264, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t266: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n",
- " # t267 = prims.slice_prim(t264, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t267: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n",
- " del t264\n",
- " t268 = torch.reshape(t265, (1, 32, 512, 128)) # t268: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t268 = ltorch.reshape(t265, (1, 32, 512, 128)) # t268: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t268 = prims.reshape(t265, (1, 32, 512, 128)) # t268: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t265\n",
- " t269 = torch.reshape(t266, (1, 32, 512, 128)) # t269: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t269 = ltorch.reshape(t266, (1, 32, 512, 128)) # t269: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t269 = prims.reshape(t266, (1, 32, 512, 128)) # t269: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t266\n",
- " t270 = torch.reshape(t267, (1, 32, 512, 128)) # t270: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t270 = ltorch.reshape(t267, (1, 32, 512, 128)) # t270: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t270 = prims.reshape(t267, (1, 32, 512, 128)) # t270: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t267\n",
- " t271 = torch_slice_prim_impl(t268, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t271: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " t281 = torch_slice_prim_impl(t269, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t281: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " t291 = torch_slice_prim_impl(t268, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t291: \"cuda:0 f32[1, 32, 512, 0]\"\n",
- " del t268\n",
- " t293 = torch_slice_prim_impl(t269, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t293: \"cuda:0 f32[1, 32, 512, 0]\"\n",
- " del t269\n",
- " t272 = torch_slice_prim_impl(t271, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t272: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " t273 = torch_slice_prim_impl(t271, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t273: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " t282 = torch_slice_prim_impl(t281, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t282: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " t283 = torch_slice_prim_impl(t281, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t283: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " [t274, t284] = nvFusion16(t273, t283)\n",
+ " t254 = torch.reshape(t251, (1, 32, 512, 128)) # t254: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t254 = ltorch.reshape(t251, (1, 32, 512, 128)) # t254: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t254 = prims.reshape(t251, (1, 32, 512, 128)) # t254: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t251\n",
+ " t285 = torch_slice_prim_impl(t252, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t285: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " t287 = torch_slice_prim_impl(t253, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t287: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " t255 = torch_slice_prim_impl(t252, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t255: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t252\n",
+ " t270 = torch_slice_prim_impl(t253, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t270: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t253\n",
+ " t256 = torch_slice_prim_impl(t255, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t256: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t257 = torch_slice_prim_impl(t255, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t257: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t272 = torch_slice_prim_impl(t270, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t272: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t271 = torch_slice_prim_impl(t270, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t271: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " [t260, t275] = nvFusion6(t255, t257, t270, t272)\n",
+ " # t258 = prims.convert_element_type(t257, dtypes.float32) # t258: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t259 = prims.neg(t258) # t259: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t260 = prims.convert_element_type(t259, dtypes.bfloat16) # t260: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " # t273 = prims.convert_element_type(t272, dtypes.float32) # t273: \"cuda:0 f32[1, 32, 512, 64]\"\n",
" # t274 = prims.neg(t273) # t274: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " # t284 = prims.neg(t283) # t284: \"cuda:0 f32[1, 32, 512, 64]\"\n",
- " del t273, t283\n",
- " t275 = torch.cat((t274, t272), -1) # t275: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t275 = ltorch.cat((t274, t272), -1) # t275: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t275 = prims.cat((t274, t272), -1) # t275: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t274, t272\n",
- " t285 = torch.cat((t284, t282), -1) # t285: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t285 = ltorch.cat((t284, t282), -1) # t285: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t285 = prims.cat((t284, t282), -1) # t285: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t284, t282\n",
- " [t280, t290] = nvFusion17(t271, t275, t281, t285, t63, t65)\n",
- " # t277 = prims.mul(t271, t63) # t277: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t279 = prims.mul(t275, t65) # t279: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t280 = prims.add(t277, t279) # t280: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t287 = prims.mul(t281, t63) # t287: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t289 = prims.mul(t285, t65) # t289: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t290 = prims.add(t287, t289) # t290: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t271, t275, t281, t285\n",
- " t292 = torch.cat((t280, t291), -1) # t292: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t292 = ltorch.cat((t280, t291), -1) # t292: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t292 = prims.cat((t280, t291), -1) # t292: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t280, t291\n",
- " t294 = torch.cat((t290, t293), -1) # t294: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t294 = ltorch.cat((t290, t293), -1) # t294: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " # t294 = prims.cat((t290, t293), -1) # t294: \"cuda:0 f32[1, 32, 512, 128]\"\n",
- " del t290, t293\n",
- " (t295, t296, t297, t298) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t292, t294, t270, None, 0.0, True, 0.08838834764831843)\n",
- " t299 = torch.permute(t295, (0, 2, 1, 3)) # t299: \"cuda:0 f32[1, 512, 32, 128]\"\n",
- " # t299 = ltorch.permute(t295, (0, 2, 1, 3)) # t299: \"cuda:0 f32[1, 512, 32, 128]\"\n",
- " # t299 = prims.transpose(t295, (0, 2, 1, 3)) # t299: \"cuda:0 f32[1, 512, 32, 128]\"\n",
- " t300 = torch.reshape(t299, (1, 512, 4096)) # t300: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t300 = ltorch.reshape(t299, (1, 512, 4096)) # t300: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t300 = prims.reshape(t299, (1, 512, 4096)) # t300: \"cuda:0 f32[1, 512, 4096]\"\n",
- " del t299\n",
- " t301 = torch.nn.functional.linear(t300, t31, None) # t301: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t301 = ltorch.linear(t300, t31, None) # t301: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t301 = prims.linear(t300, t31, None) # t301: \"cuda:0 f32[1, 512, 4096]\"\n",
- " [t302, t308, t312] = nvFusion18(t251, t301, t311)\n",
- " # t302 = prims.add(t301, t251) # t302: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t303 = prims.mul(t302, t302) # t303: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t304 = prims.sum(t303, (2,)) # t304: \"cuda:0 f32[1, 512]\"\n",
- " # t305 = prims.broadcast_in_dim(t304, [1, 512, 1], [0, 1]) # t305: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t306 = prims.div(t305, 4096.0) # t306: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t307 = prims.add(t306, 1e-05) # t307: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t308 = prims.rsqrt(t307) # t308: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t309 = prims.broadcast_in_dim(t308, (1, 512, 4096), (0, 1, 2)) # t309: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t310 = prims.mul(t302, t309) # t310: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t312 = prims.mul(t310, t311) # t312: \"cuda:0 f32[1, 512, 4096]\"\n",
- " del t301\n",
- " t314 = torch.nn.functional.linear(t312, t14, None) # t314: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t314 = ltorch.linear(t312, t14, None) # t314: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t314 = prims.linear(t312, t14, None) # t314: \"cuda:0 f32[1, 512, 11008]\"\n",
- " t313 = torch.nn.functional.linear(t312, t10, None) # t313: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t313 = ltorch.linear(t312, t10, None) # t313: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t313 = prims.linear(t312, t10, None) # t313: \"cuda:0 f32[1, 512, 11008]\"\n",
- " [t320] = nvFusion19(t313, t314)\n",
- " # t315 = prims.neg(t313) # t315: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t316 = prims.exp(t315) # t316: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t317 = prims.add(1.0, t316) # t317: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t318 = prims.reciprocal(t317) # t318: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t319 = prims.mul(t313, t318) # t319: \"cuda:0 f32[1, 512, 11008]\"\n",
- " # t320 = prims.mul(t319, t314) # t320: \"cuda:0 f32[1, 512, 11008]\"\n",
- " t321 = torch.nn.functional.linear(t320, t32, None) # t321: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t321 = ltorch.linear(t320, t32, None) # t321: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t321 = prims.linear(t320, t32, None) # t321: \"cuda:0 f32[1, 512, 4096]\"\n",
- " [t322, t328, t332] = nvFusion20(t302, t321, t331)\n",
- " # t322 = prims.add(t321, t302) # t322: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t323 = prims.mul(t322, t322) # t323: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t324 = prims.sum(t323, (2,)) # t324: \"cuda:0 f32[1, 512]\"\n",
- " # t325 = prims.broadcast_in_dim(t324, [1, 512, 1], [0, 1]) # t325: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t326 = prims.div(t325, 4096.0) # t326: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t327 = prims.add(t326, 1e-05) # t327: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t328 = prims.rsqrt(t327) # t328: \"cuda:0 f32[1, 512, 1]\"\n",
- " # t329 = prims.broadcast_in_dim(t328, (1, 512, 4096), (0, 1, 2)) # t329: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t330 = prims.mul(t322, t329) # t330: \"cuda:0 f32[1, 512, 4096]\"\n",
- " # t332 = prims.mul(t330, t331) # t332: \"cuda:0 f32[1, 512, 4096]\"\n",
- " del t321\n",
- " t333 = torch.nn.functional.linear(t332, t15, None) # t333: \"cuda:0 f32[1, 512, 32000]\"\n",
- " # t333 = ltorch.linear(t332, t15, None) # t333: \"cuda:0 f32[1, 512, 32000]\"\n",
- " # t333 = prims.linear(t332, t15, None) # t333: \"cuda:0 f32[1, 512, 32000]\"\n",
- " return {'output': t333, 'flat_args': [t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19, t20, t21, t22, t23, t24, t25, t26, t27, t28, t29, t30, t31, t32, t33], 'flat_output': (t333,)}, ((t0, t10, t100, t101, t107, t109, t11, t115, t118, t119, t12, t128, t13, t14, t15, t150, t152, t153, t154, t155, t156, t158, t160, t166, t169, t170, t171, t172, t178, t180, t186, t189, t190, t199, t221, t223, t224, t225, t226, t227, t229, t231, t237, t240, t241, t242, t243, t249, t25, t251, t257, t26, t260, t261, t27, t270, t28, t29, t292, t294, t295, t296, t297, t298, t3, t30, t300, t302, t308, t31, t311, t312, t313, t314, t32, t320, t322, t328, t331, t332, t38, t4, t44, t47, t48, t5, t57, t6, t63, t65, t7, t79, t8, t81, t82, t83, t84, t85, t87, t89, t9, t95, t98, t99), (False, True, True, False, True, True, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 32000, 2, 2, 2, 2))"
+ " # t275 = prims.convert_element_type(t274, dtypes.bfloat16) # t275: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " del t257, t272\n",
+ " t261 = torch.cat((t260, t256), -1) # t261: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t261 = ltorch.cat((t260, t256), -1) # t261: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t261 = prims.cat((t260, t256), -1) # t261: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t260, t256\n",
+ " t276 = torch.cat((t275, t271), -1) # t276: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t276 = ltorch.cat((t275, t271), -1) # t276: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t276 = prims.cat((t275, t271), -1) # t276: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t275, t271\n",
+ " [t269, t284] = nvFusion7(t154, t157, t255, t261, t270, t276)\n",
+ " # t263 = prims.convert_element_type(t255, dtypes.float32) # t263: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t278 = prims.convert_element_type(t270, dtypes.float32) # t278: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t264 = prims.mul(t263, t154) # t264: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t266 = prims.convert_element_type(t261, dtypes.float32) # t266: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t267 = prims.mul(t266, t157) # t267: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t268 = prims.add(t264, t267) # t268: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t269 = prims.convert_element_type(t268, dtypes.bfloat16) # t269: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t279 = prims.mul(t278, t154) # t279: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t281 = prims.convert_element_type(t276, dtypes.float32) # t281: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t282 = prims.mul(t281, t157) # t282: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t283 = prims.add(t279, t282) # t283: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t284 = prims.convert_element_type(t283, dtypes.bfloat16) # t284: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t255, t261, t270, t276\n",
+ " t288 = torch.cat((t284, t287), -1) # t288: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t288 = ltorch.cat((t284, t287), -1) # t288: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t288 = prims.cat((t284, t287), -1) # t288: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t284, t287\n",
+ " t286 = torch.cat((t269, t285), -1) # t286: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t286 = ltorch.cat((t269, t285), -1) # t286: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t286 = prims.cat((t269, t285), -1) # t286: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t269, t285\n",
+ " (t289, t290, t291, t292, _, _, t293, t294, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t286, t288, t254, 0.0, True, scale=0.08838834764831843)\n",
+ " t296 = torch.permute(t289, (0, 2, 1, 3)) # t296: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t296 = ltorch.permute(t289, (0, 2, 1, 3)) # t296: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t296 = prims.transpose(t289, (0, 2, 1, 3)) # t296: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " t297 = torch.reshape(t296, (1, 512, 4096)) # t297: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t297 = ltorch.reshape(t296, (1, 512, 4096)) # t297: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t297 = prims.reshape(t296, (1, 512, 4096)) # t297: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t296\n",
+ " t298 = torch.nn.functional.linear(t297, t87, None) # t298: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t298 = ltorch.linear(t297, t87, None) # t298: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t298 = prims.linear(t297, t87, None) # t298: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t302, t309, t317] = nvFusion8(t230, t298, t313)\n",
+ " # t300 = prims.convert_element_type(t230, dtypes.float32) # t300: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t299 = prims.convert_element_type(t298, dtypes.float32) # t299: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t301 = prims.add(t299, t300) # t301: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t302 = prims.convert_element_type(t301, dtypes.bfloat16) # t302: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t304 = prims.mul(t301, t301) # t304: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t305 = prims.sum(t304, (2,)) # t305: \"cuda:0 f32[1, 512]\"\n",
+ " # t306 = prims.broadcast_in_dim(t305, [1, 512, 1], [0, 1]) # t306: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t307 = prims.div(t306, 4096.0) # t307: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t308 = prims.add(t307, 1e-05) # t308: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t309 = prims.rsqrt(t308) # t309: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t310 = prims.broadcast_in_dim(t309, (1, 512, 4096), (0, 1, 2)) # t310: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t311 = prims.mul(t301, t310) # t311: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t315 = prims.convert_element_type(t313, dtypes.float32) # t315: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t316 = prims.mul(t311, t315) # t316: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t317 = prims.convert_element_type(t316, dtypes.bfloat16) # t317: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t318 = torch.nn.functional.linear(t317, t20, None) # t318: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t318 = ltorch.linear(t317, t20, None) # t318: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t318 = prims.linear(t317, t20, None) # t318: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t319 = torch.nn.functional.linear(t317, t36, None) # t319: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t319 = ltorch.linear(t317, t36, None) # t319: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t319 = prims.linear(t317, t36, None) # t319: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " [t333] = nvFusion9(t318, t319)\n",
+ " # t320 = prims.convert_element_type(t318, dtypes.float32) # t320: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t321 = prims.neg(t320) # t321: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t322 = prims.exp(t321) # t322: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t323 = prims.add(1.0, t322) # t323: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t324 = prims.reciprocal(t323) # t324: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t328 = prims.mul(t320, t324) # t328: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t331 = prims.convert_element_type(t319, dtypes.float32) # t331: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t332 = prims.mul(t328, t331) # t332: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t333 = prims.convert_element_type(t332, dtypes.bfloat16) # t333: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t334 = torch.nn.functional.linear(t333, t88, None) # t334: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t334 = ltorch.linear(t333, t88, None) # t334: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t334 = prims.linear(t333, t88, None) # t334: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t338, t345, t353] = nvFusion10(t302, t334, t349)\n",
+ " # t336 = prims.convert_element_type(t302, dtypes.float32) # t336: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t335 = prims.convert_element_type(t334, dtypes.float32) # t335: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t337 = prims.add(t335, t336) # t337: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t338 = prims.convert_element_type(t337, dtypes.bfloat16) # t338: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t340 = prims.mul(t337, t337) # t340: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t341 = prims.sum(t340, (2,)) # t341: \"cuda:0 f32[1, 512]\"\n",
+ " # t342 = prims.broadcast_in_dim(t341, [1, 512, 1], [0, 1]) # t342: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t343 = prims.div(t342, 4096.0) # t343: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t344 = prims.add(t343, 1e-05) # t344: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t345 = prims.rsqrt(t344) # t345: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t346 = prims.broadcast_in_dim(t345, (1, 512, 4096), (0, 1, 2)) # t346: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t347 = prims.mul(t337, t346) # t347: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t351 = prims.convert_element_type(t349, dtypes.float32) # t351: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t352 = prims.mul(t347, t351) # t352: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t353 = prims.convert_element_type(t352, dtypes.bfloat16) # t353: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t354 = torch.nn.functional.linear(t353, t5, None) # t354: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t354 = ltorch.linear(t353, t5, None) # t354: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t354 = prims.linear(t353, t5, None) # t354: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " t355 = torch.reshape(t354, (1, 512, 32, 3, 128)) # t355: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t355 = ltorch.reshape(t354, (1, 512, 32, 3, 128)) # t355: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t355 = prims.reshape(t354, (1, 512, 32, 3, 128)) # t355: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " del t354\n",
+ " t356 = torch.permute(t355, (0, 2, 3, 1, 4)) # t356: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t356 = ltorch.permute(t355, (0, 2, 3, 1, 4)) # t356: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t356 = prims.transpose(t355, (0, 2, 3, 1, 4)) # t356: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " del t355\n",
+ " (t357, t358, t359) = torch.split(t356, (1, 1, 1), 2)\n",
+ " # (t357, t358, t359) = ltorch.split(t356, (1, 1, 1), 2)\n",
+ " # t357 = prims.slice_prim(t356, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t357: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t358 = prims.slice_prim(t356, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t358: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t359 = prims.slice_prim(t356, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t359: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " del t356\n",
+ " t360 = torch.reshape(t357, (1, 32, 512, 128)) # t360: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t360 = ltorch.reshape(t357, (1, 32, 512, 128)) # t360: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t360 = prims.reshape(t357, (1, 32, 512, 128)) # t360: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t357\n",
+ " t361 = torch.reshape(t358, (1, 32, 512, 128)) # t361: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t361 = ltorch.reshape(t358, (1, 32, 512, 128)) # t361: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t361 = prims.reshape(t358, (1, 32, 512, 128)) # t361: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t358\n",
+ " t362 = torch.reshape(t359, (1, 32, 512, 128)) # t362: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t362 = ltorch.reshape(t359, (1, 32, 512, 128)) # t362: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t362 = prims.reshape(t359, (1, 32, 512, 128)) # t362: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t359\n",
+ " t363 = torch_slice_prim_impl(t360, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t363: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t378 = torch_slice_prim_impl(t361, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t378: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t393 = torch_slice_prim_impl(t360, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t393: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t360\n",
+ " t395 = torch_slice_prim_impl(t361, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t395: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t361\n",
+ " t364 = torch_slice_prim_impl(t363, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t364: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t365 = torch_slice_prim_impl(t363, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t365: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t379 = torch_slice_prim_impl(t378, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t379: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t380 = torch_slice_prim_impl(t378, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t380: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " [t368, t383] = nvFusion11(t363, t365, t378, t380)\n",
+ " # t366 = prims.convert_element_type(t365, dtypes.float32) # t366: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t367 = prims.neg(t366) # t367: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t368 = prims.convert_element_type(t367, dtypes.bfloat16) # t368: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " # t381 = prims.convert_element_type(t380, dtypes.float32) # t381: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t382 = prims.neg(t381) # t382: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t383 = prims.convert_element_type(t382, dtypes.bfloat16) # t383: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " del t365, t380\n",
+ " t369 = torch.cat((t368, t364), -1) # t369: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t369 = ltorch.cat((t368, t364), -1) # t369: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t369 = prims.cat((t368, t364), -1) # t369: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t368, t364\n",
+ " t384 = torch.cat((t383, t379), -1) # t384: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t384 = ltorch.cat((t383, t379), -1) # t384: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t384 = prims.cat((t383, t379), -1) # t384: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t383, t379\n",
+ " [t377, t392] = nvFusion12(t154, t157, t363, t369, t378, t384)\n",
+ " # t371 = prims.convert_element_type(t363, dtypes.float32) # t371: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t386 = prims.convert_element_type(t378, dtypes.float32) # t386: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t372 = prims.mul(t371, t154) # t372: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t374 = prims.convert_element_type(t369, dtypes.float32) # t374: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t375 = prims.mul(t374, t157) # t375: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t376 = prims.add(t372, t375) # t376: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t377 = prims.convert_element_type(t376, dtypes.bfloat16) # t377: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t387 = prims.mul(t386, t154) # t387: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t389 = prims.convert_element_type(t384, dtypes.float32) # t389: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t390 = prims.mul(t389, t157) # t390: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t391 = prims.add(t387, t390) # t391: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t392 = prims.convert_element_type(t391, dtypes.bfloat16) # t392: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t363, t369, t378, t384\n",
+ " t394 = torch.cat((t377, t393), -1) # t394: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t394 = ltorch.cat((t377, t393), -1) # t394: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t394 = prims.cat((t377, t393), -1) # t394: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t377, t393\n",
+ " t396 = torch.cat((t392, t395), -1) # t396: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t396 = ltorch.cat((t392, t395), -1) # t396: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t396 = prims.cat((t392, t395), -1) # t396: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t392, t395\n",
+ " (t397, t398, t399, t400, _, _, t401, t402, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t394, t396, t362, 0.0, True, scale=0.08838834764831843)\n",
+ " t404 = torch.permute(t397, (0, 2, 1, 3)) # t404: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t404 = ltorch.permute(t397, (0, 2, 1, 3)) # t404: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t404 = prims.transpose(t397, (0, 2, 1, 3)) # t404: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " t405 = torch.reshape(t404, (1, 512, 4096)) # t405: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t405 = ltorch.reshape(t404, (1, 512, 4096)) # t405: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t405 = prims.reshape(t404, (1, 512, 4096)) # t405: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t404\n",
+ " t406 = torch.nn.functional.linear(t405, t89, None) # t406: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t406 = ltorch.linear(t405, t89, None) # t406: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t406 = prims.linear(t405, t89, None) # t406: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t410, t417, t425] = nvFusion13(t338, t406, t421)\n",
+ " # t408 = prims.convert_element_type(t338, dtypes.float32) # t408: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t407 = prims.convert_element_type(t406, dtypes.float32) # t407: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t409 = prims.add(t407, t408) # t409: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t410 = prims.convert_element_type(t409, dtypes.bfloat16) # t410: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t412 = prims.mul(t409, t409) # t412: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t413 = prims.sum(t412, (2,)) # t413: \"cuda:0 f32[1, 512]\"\n",
+ " # t414 = prims.broadcast_in_dim(t413, [1, 512, 1], [0, 1]) # t414: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t415 = prims.div(t414, 4096.0) # t415: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t416 = prims.add(t415, 1e-05) # t416: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t417 = prims.rsqrt(t416) # t417: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t418 = prims.broadcast_in_dim(t417, (1, 512, 4096), (0, 1, 2)) # t418: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t419 = prims.mul(t409, t418) # t419: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t423 = prims.convert_element_type(t421, dtypes.float32) # t423: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t424 = prims.mul(t419, t423) # t424: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t425 = prims.convert_element_type(t424, dtypes.bfloat16) # t425: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t426 = torch.nn.functional.linear(t425, t21, None) # t426: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t426 = ltorch.linear(t425, t21, None) # t426: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t426 = prims.linear(t425, t21, None) # t426: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t427 = torch.nn.functional.linear(t425, t37, None) # t427: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t427 = ltorch.linear(t425, t37, None) # t427: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t427 = prims.linear(t425, t37, None) # t427: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " [t441] = nvFusion14(t426, t427)\n",
+ " # t428 = prims.convert_element_type(t426, dtypes.float32) # t428: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t429 = prims.neg(t428) # t429: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t430 = prims.exp(t429) # t430: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t431 = prims.add(1.0, t430) # t431: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t432 = prims.reciprocal(t431) # t432: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t436 = prims.mul(t428, t432) # t436: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t439 = prims.convert_element_type(t427, dtypes.float32) # t439: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t440 = prims.mul(t436, t439) # t440: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t441 = prims.convert_element_type(t440, dtypes.bfloat16) # t441: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t442 = torch.nn.functional.linear(t441, t90, None) # t442: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t442 = ltorch.linear(t441, t90, None) # t442: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t442 = prims.linear(t441, t90, None) # t442: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t446, t453, t461] = nvFusion15(t410, t442, t457)\n",
+ " # t444 = prims.convert_element_type(t410, dtypes.float32) # t444: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t443 = prims.convert_element_type(t442, dtypes.float32) # t443: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t445 = prims.add(t443, t444) # t445: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t446 = prims.convert_element_type(t445, dtypes.bfloat16) # t446: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t448 = prims.mul(t445, t445) # t448: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t449 = prims.sum(t448, (2,)) # t449: \"cuda:0 f32[1, 512]\"\n",
+ " # t450 = prims.broadcast_in_dim(t449, [1, 512, 1], [0, 1]) # t450: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t451 = prims.div(t450, 4096.0) # t451: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t452 = prims.add(t451, 1e-05) # t452: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t453 = prims.rsqrt(t452) # t453: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t454 = prims.broadcast_in_dim(t453, (1, 512, 4096), (0, 1, 2)) # t454: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t455 = prims.mul(t445, t454) # t455: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t459 = prims.convert_element_type(t457, dtypes.float32) # t459: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t460 = prims.mul(t455, t459) # t460: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t461 = prims.convert_element_type(t460, dtypes.bfloat16) # t461: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t462 = torch.nn.functional.linear(t461, t6, None) # t462: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t462 = ltorch.linear(t461, t6, None) # t462: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t462 = prims.linear(t461, t6, None) # t462: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " t463 = torch.reshape(t462, (1, 512, 32, 3, 128)) # t463: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t463 = ltorch.reshape(t462, (1, 512, 32, 3, 128)) # t463: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t463 = prims.reshape(t462, (1, 512, 32, 3, 128)) # t463: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " del t462\n",
+ " t464 = torch.permute(t463, (0, 2, 3, 1, 4)) # t464: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t464 = ltorch.permute(t463, (0, 2, 3, 1, 4)) # t464: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t464 = prims.transpose(t463, (0, 2, 3, 1, 4)) # t464: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " del t463\n",
+ " (t465, t466, t467) = torch.split(t464, (1, 1, 1), 2)\n",
+ " # (t465, t466, t467) = ltorch.split(t464, (1, 1, 1), 2)\n",
+ " # t465 = prims.slice_prim(t464, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t465: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t466 = prims.slice_prim(t464, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t466: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t467 = prims.slice_prim(t464, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t467: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " del t464\n",
+ " t468 = torch.reshape(t465, (1, 32, 512, 128)) # t468: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t468 = ltorch.reshape(t465, (1, 32, 512, 128)) # t468: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t468 = prims.reshape(t465, (1, 32, 512, 128)) # t468: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t465\n",
+ " t469 = torch.reshape(t466, (1, 32, 512, 128)) # t469: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t469 = ltorch.reshape(t466, (1, 32, 512, 128)) # t469: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t469 = prims.reshape(t466, (1, 32, 512, 128)) # t469: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t466\n",
+ " t470 = torch.reshape(t467, (1, 32, 512, 128)) # t470: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t470 = ltorch.reshape(t467, (1, 32, 512, 128)) # t470: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t470 = prims.reshape(t467, (1, 32, 512, 128)) # t470: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t467\n",
+ " t471 = torch_slice_prim_impl(t468, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t471: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t486 = torch_slice_prim_impl(t469, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t486: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t501 = torch_slice_prim_impl(t468, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t501: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t468\n",
+ " t503 = torch_slice_prim_impl(t469, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t503: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t469\n",
+ " t472 = torch_slice_prim_impl(t471, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t472: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t473 = torch_slice_prim_impl(t471, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t473: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t487 = torch_slice_prim_impl(t486, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t487: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t488 = torch_slice_prim_impl(t486, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t488: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " [t476, t491] = nvFusion16(t471, t473, t486, t488)\n",
+ " # t474 = prims.convert_element_type(t473, dtypes.float32) # t474: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t475 = prims.neg(t474) # t475: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t476 = prims.convert_element_type(t475, dtypes.bfloat16) # t476: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " # t489 = prims.convert_element_type(t488, dtypes.float32) # t489: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t490 = prims.neg(t489) # t490: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t491 = prims.convert_element_type(t490, dtypes.bfloat16) # t491: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " del t473, t488\n",
+ " t477 = torch.cat((t476, t472), -1) # t477: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t477 = ltorch.cat((t476, t472), -1) # t477: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t477 = prims.cat((t476, t472), -1) # t477: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t476, t472\n",
+ " t492 = torch.cat((t491, t487), -1) # t492: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t492 = ltorch.cat((t491, t487), -1) # t492: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t492 = prims.cat((t491, t487), -1) # t492: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t491, t487\n",
+ " [t485, t500] = nvFusion17(t154, t157, t471, t477, t486, t492)\n",
+ " # t479 = prims.convert_element_type(t471, dtypes.float32) # t479: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t494 = prims.convert_element_type(t486, dtypes.float32) # t494: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t480 = prims.mul(t479, t154) # t480: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t482 = prims.convert_element_type(t477, dtypes.float32) # t482: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t483 = prims.mul(t482, t157) # t483: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t484 = prims.add(t480, t483) # t484: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t485 = prims.convert_element_type(t484, dtypes.bfloat16) # t485: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t495 = prims.mul(t494, t154) # t495: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t497 = prims.convert_element_type(t492, dtypes.float32) # t497: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t498 = prims.mul(t497, t157) # t498: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t499 = prims.add(t495, t498) # t499: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t500 = prims.convert_element_type(t499, dtypes.bfloat16) # t500: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t471, t477, t486, t492\n",
+ " t502 = torch.cat((t485, t501), -1) # t502: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t502 = ltorch.cat((t485, t501), -1) # t502: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t502 = prims.cat((t485, t501), -1) # t502: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t485, t501\n",
+ " t504 = torch.cat((t500, t503), -1) # t504: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t504 = ltorch.cat((t500, t503), -1) # t504: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t504 = prims.cat((t500, t503), -1) # t504: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t500, t503\n",
+ " (t505, t506, t507, t508, _, _, t509, t510, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t502, t504, t470, 0.0, True, scale=0.08838834764831843)\n",
+ " t512 = torch.permute(t505, (0, 2, 1, 3)) # t512: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t512 = ltorch.permute(t505, (0, 2, 1, 3)) # t512: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t512 = prims.transpose(t505, (0, 2, 1, 3)) # t512: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " t513 = torch.reshape(t512, (1, 512, 4096)) # t513: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t513 = ltorch.reshape(t512, (1, 512, 4096)) # t513: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t513 = prims.reshape(t512, (1, 512, 4096)) # t513: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t512\n",
+ " t514 = torch.nn.functional.linear(t513, t91, None) # t514: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t514 = ltorch.linear(t513, t91, None) # t514: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t514 = prims.linear(t513, t91, None) # t514: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t518, t525, t533] = nvFusion18(t446, t514, t529)\n",
+ " # t516 = prims.convert_element_type(t446, dtypes.float32) # t516: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t515 = prims.convert_element_type(t514, dtypes.float32) # t515: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t517 = prims.add(t515, t516) # t517: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t518 = prims.convert_element_type(t517, dtypes.bfloat16) # t518: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t520 = prims.mul(t517, t517) # t520: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t521 = prims.sum(t520, (2,)) # t521: \"cuda:0 f32[1, 512]\"\n",
+ " # t522 = prims.broadcast_in_dim(t521, [1, 512, 1], [0, 1]) # t522: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t523 = prims.div(t522, 4096.0) # t523: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t524 = prims.add(t523, 1e-05) # t524: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t525 = prims.rsqrt(t524) # t525: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t526 = prims.broadcast_in_dim(t525, (1, 512, 4096), (0, 1, 2)) # t526: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t527 = prims.mul(t517, t526) # t527: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t531 = prims.convert_element_type(t529, dtypes.float32) # t531: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t532 = prims.mul(t527, t531) # t532: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t533 = prims.convert_element_type(t532, dtypes.bfloat16) # t533: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t534 = torch.nn.functional.linear(t533, t22, None) # t534: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t534 = ltorch.linear(t533, t22, None) # t534: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t534 = prims.linear(t533, t22, None) # t534: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t535 = torch.nn.functional.linear(t533, t38, None) # t535: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t535 = ltorch.linear(t533, t38, None) # t535: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t535 = prims.linear(t533, t38, None) # t535: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " [t549] = nvFusion19(t534, t535)\n",
+ " # t536 = prims.convert_element_type(t534, dtypes.float32) # t536: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t537 = prims.neg(t536) # t537: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t538 = prims.exp(t537) # t538: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t539 = prims.add(1.0, t538) # t539: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t540 = prims.reciprocal(t539) # t540: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t544 = prims.mul(t536, t540) # t544: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t547 = prims.convert_element_type(t535, dtypes.float32) # t547: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t548 = prims.mul(t544, t547) # t548: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t549 = prims.convert_element_type(t548, dtypes.bfloat16) # t549: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t550 = torch.nn.functional.linear(t549, t92, None) # t550: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t550 = ltorch.linear(t549, t92, None) # t550: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t550 = prims.linear(t549, t92, None) # t550: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t554, t561, t569] = nvFusion20(t518, t550, t565)\n",
+ " # t552 = prims.convert_element_type(t518, dtypes.float32) # t552: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t551 = prims.convert_element_type(t550, dtypes.float32) # t551: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t553 = prims.add(t551, t552) # t553: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t554 = prims.convert_element_type(t553, dtypes.bfloat16) # t554: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t556 = prims.mul(t553, t553) # t556: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t557 = prims.sum(t556, (2,)) # t557: \"cuda:0 f32[1, 512]\"\n",
+ " # t558 = prims.broadcast_in_dim(t557, [1, 512, 1], [0, 1]) # t558: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t559 = prims.div(t558, 4096.0) # t559: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t560 = prims.add(t559, 1e-05) # t560: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t561 = prims.rsqrt(t560) # t561: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t562 = prims.broadcast_in_dim(t561, (1, 512, 4096), (0, 1, 2)) # t562: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t563 = prims.mul(t553, t562) # t563: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t567 = prims.convert_element_type(t565, dtypes.float32) # t567: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t568 = prims.mul(t563, t567) # t568: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t569 = prims.convert_element_type(t568, dtypes.bfloat16) # t569: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t570 = torch.nn.functional.linear(t569, t7, None) # t570: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t570 = ltorch.linear(t569, t7, None) # t570: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t570 = prims.linear(t569, t7, None) # t570: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " t571 = torch.reshape(t570, (1, 512, 32, 3, 128)) # t571: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t571 = ltorch.reshape(t570, (1, 512, 32, 3, 128)) # t571: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t571 = prims.reshape(t570, (1, 512, 32, 3, 128)) # t571: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " del t570\n",
+ " t572 = torch.permute(t571, (0, 2, 3, 1, 4)) # t572: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t572 = ltorch.permute(t571, (0, 2, 3, 1, 4)) # t572: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t572 = prims.transpose(t571, (0, 2, 3, 1, 4)) # t572: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " del t571\n",
+ " (t573, t574, t575) = torch.split(t572, (1, 1, 1), 2)\n",
+ " # (t573, t574, t575) = ltorch.split(t572, (1, 1, 1), 2)\n",
+ " # t573 = prims.slice_prim(t572, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t573: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t574 = prims.slice_prim(t572, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t574: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t575 = prims.slice_prim(t572, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t575: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " del t572\n",
+ " t576 = torch.reshape(t573, (1, 32, 512, 128)) # t576: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t576 = ltorch.reshape(t573, (1, 32, 512, 128)) # t576: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t576 = prims.reshape(t573, (1, 32, 512, 128)) # t576: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t573\n",
+ " t577 = torch.reshape(t574, (1, 32, 512, 128)) # t577: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t577 = ltorch.reshape(t574, (1, 32, 512, 128)) # t577: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t577 = prims.reshape(t574, (1, 32, 512, 128)) # t577: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t574\n",
+ " t578 = torch.reshape(t575, (1, 32, 512, 128)) # t578: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t578 = ltorch.reshape(t575, (1, 32, 512, 128)) # t578: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t578 = prims.reshape(t575, (1, 32, 512, 128)) # t578: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t575\n",
+ " t579 = torch_slice_prim_impl(t576, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t579: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t594 = torch_slice_prim_impl(t577, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t594: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t609 = torch_slice_prim_impl(t576, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t609: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t576\n",
+ " t611 = torch_slice_prim_impl(t577, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t611: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t577\n",
+ " t580 = torch_slice_prim_impl(t579, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t580: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t581 = torch_slice_prim_impl(t579, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t581: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t595 = torch_slice_prim_impl(t594, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t595: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t596 = torch_slice_prim_impl(t594, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t596: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " [t584, t599] = nvFusion21(t579, t581, t594, t596)\n",
+ " # t582 = prims.convert_element_type(t581, dtypes.float32) # t582: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t583 = prims.neg(t582) # t583: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t584 = prims.convert_element_type(t583, dtypes.bfloat16) # t584: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " # t597 = prims.convert_element_type(t596, dtypes.float32) # t597: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t598 = prims.neg(t597) # t598: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t599 = prims.convert_element_type(t598, dtypes.bfloat16) # t599: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " del t581, t596\n",
+ " t600 = torch.cat((t599, t595), -1) # t600: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t600 = ltorch.cat((t599, t595), -1) # t600: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t600 = prims.cat((t599, t595), -1) # t600: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t599, t595\n",
+ " t585 = torch.cat((t584, t580), -1) # t585: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t585 = ltorch.cat((t584, t580), -1) # t585: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t585 = prims.cat((t584, t580), -1) # t585: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t584, t580\n",
+ " [t593, t608] = nvFusion22(t154, t157, t579, t585, t594, t600)\n",
+ " # t587 = prims.convert_element_type(t579, dtypes.float32) # t587: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t602 = prims.convert_element_type(t594, dtypes.float32) # t602: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t603 = prims.mul(t602, t154) # t603: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t605 = prims.convert_element_type(t600, dtypes.float32) # t605: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t606 = prims.mul(t605, t157) # t606: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t607 = prims.add(t603, t606) # t607: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t608 = prims.convert_element_type(t607, dtypes.bfloat16) # t608: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t588 = prims.mul(t587, t154) # t588: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t590 = prims.convert_element_type(t585, dtypes.float32) # t590: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t591 = prims.mul(t590, t157) # t591: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t592 = prims.add(t588, t591) # t592: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t593 = prims.convert_element_type(t592, dtypes.bfloat16) # t593: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t579, t585, t594, t600\n",
+ " t612 = torch.cat((t608, t611), -1) # t612: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t612 = ltorch.cat((t608, t611), -1) # t612: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t612 = prims.cat((t608, t611), -1) # t612: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t608, t611\n",
+ " t610 = torch.cat((t593, t609), -1) # t610: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t610 = ltorch.cat((t593, t609), -1) # t610: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t610 = prims.cat((t593, t609), -1) # t610: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t593, t609\n",
+ " (t613, t614, t615, t616, _, _, t617, t618, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t610, t612, t578, 0.0, True, scale=0.08838834764831843)\n",
+ " t620 = torch.permute(t613, (0, 2, 1, 3)) # t620: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t620 = ltorch.permute(t613, (0, 2, 1, 3)) # t620: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t620 = prims.transpose(t613, (0, 2, 1, 3)) # t620: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " t621 = torch.reshape(t620, (1, 512, 4096)) # t621: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t621 = ltorch.reshape(t620, (1, 512, 4096)) # t621: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t621 = prims.reshape(t620, (1, 512, 4096)) # t621: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t620\n",
+ " t622 = torch.nn.functional.linear(t621, t93, None) # t622: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t622 = ltorch.linear(t621, t93, None) # t622: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t622 = prims.linear(t621, t93, None) # t622: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t626, t633, t641] = nvFusion23(t554, t622, t637)\n",
+ " # t624 = prims.convert_element_type(t554, dtypes.float32) # t624: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t623 = prims.convert_element_type(t622, dtypes.float32) # t623: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t625 = prims.add(t623, t624) # t625: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t626 = prims.convert_element_type(t625, dtypes.bfloat16) # t626: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t628 = prims.mul(t625, t625) # t628: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t629 = prims.sum(t628, (2,)) # t629: \"cuda:0 f32[1, 512]\"\n",
+ " # t630 = prims.broadcast_in_dim(t629, [1, 512, 1], [0, 1]) # t630: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t631 = prims.div(t630, 4096.0) # t631: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t632 = prims.add(t631, 1e-05) # t632: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t633 = prims.rsqrt(t632) # t633: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t634 = prims.broadcast_in_dim(t633, (1, 512, 4096), (0, 1, 2)) # t634: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t635 = prims.mul(t625, t634) # t635: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t639 = prims.convert_element_type(t637, dtypes.float32) # t639: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t640 = prims.mul(t635, t639) # t640: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t641 = prims.convert_element_type(t640, dtypes.bfloat16) # t641: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t643 = torch.nn.functional.linear(t641, t39, None) # t643: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t643 = ltorch.linear(t641, t39, None) # t643: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t643 = prims.linear(t641, t39, None) # t643: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t642 = torch.nn.functional.linear(t641, t23, None) # t642: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t642 = ltorch.linear(t641, t23, None) # t642: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t642 = prims.linear(t641, t23, None) # t642: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " [t657] = nvFusion24(t642, t643)\n",
+ " # t644 = prims.convert_element_type(t642, dtypes.float32) # t644: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t645 = prims.neg(t644) # t645: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t646 = prims.exp(t645) # t646: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t647 = prims.add(1.0, t646) # t647: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t648 = prims.reciprocal(t647) # t648: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t652 = prims.mul(t644, t648) # t652: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t655 = prims.convert_element_type(t643, dtypes.float32) # t655: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t656 = prims.mul(t652, t655) # t656: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t657 = prims.convert_element_type(t656, dtypes.bfloat16) # t657: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t658 = torch.nn.functional.linear(t657, t94, None) # t658: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t658 = ltorch.linear(t657, t94, None) # t658: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t658 = prims.linear(t657, t94, None) # t658: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t662, t669, t677] = nvFusion25(t626, t658, t673)\n",
+ " # t660 = prims.convert_element_type(t626, dtypes.float32) # t660: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t659 = prims.convert_element_type(t658, dtypes.float32) # t659: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t661 = prims.add(t659, t660) # t661: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t662 = prims.convert_element_type(t661, dtypes.bfloat16) # t662: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t664 = prims.mul(t661, t661) # t664: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t665 = prims.sum(t664, (2,)) # t665: \"cuda:0 f32[1, 512]\"\n",
+ " # t666 = prims.broadcast_in_dim(t665, [1, 512, 1], [0, 1]) # t666: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t667 = prims.div(t666, 4096.0) # t667: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t668 = prims.add(t667, 1e-05) # t668: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t669 = prims.rsqrt(t668) # t669: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t670 = prims.broadcast_in_dim(t669, (1, 512, 4096), (0, 1, 2)) # t670: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t671 = prims.mul(t661, t670) # t671: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t675 = prims.convert_element_type(t673, dtypes.float32) # t675: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t676 = prims.mul(t671, t675) # t676: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t677 = prims.convert_element_type(t676, dtypes.bfloat16) # t677: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t678 = torch.nn.functional.linear(t677, t8, None) # t678: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t678 = ltorch.linear(t677, t8, None) # t678: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t678 = prims.linear(t677, t8, None) # t678: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " t679 = torch.reshape(t678, (1, 512, 32, 3, 128)) # t679: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t679 = ltorch.reshape(t678, (1, 512, 32, 3, 128)) # t679: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t679 = prims.reshape(t678, (1, 512, 32, 3, 128)) # t679: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " del t678\n",
+ " t680 = torch.permute(t679, (0, 2, 3, 1, 4)) # t680: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t680 = ltorch.permute(t679, (0, 2, 3, 1, 4)) # t680: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t680 = prims.transpose(t679, (0, 2, 3, 1, 4)) # t680: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " del t679\n",
+ " (t681, t682, t683) = torch.split(t680, (1, 1, 1), 2)\n",
+ " # (t681, t682, t683) = ltorch.split(t680, (1, 1, 1), 2)\n",
+ " # t681 = prims.slice_prim(t680, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t681: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t682 = prims.slice_prim(t680, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t682: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t683 = prims.slice_prim(t680, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t683: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " del t680\n",
+ " t684 = torch.reshape(t681, (1, 32, 512, 128)) # t684: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t684 = ltorch.reshape(t681, (1, 32, 512, 128)) # t684: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t684 = prims.reshape(t681, (1, 32, 512, 128)) # t684: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t681\n",
+ " t685 = torch.reshape(t682, (1, 32, 512, 128)) # t685: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t685 = ltorch.reshape(t682, (1, 32, 512, 128)) # t685: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t685 = prims.reshape(t682, (1, 32, 512, 128)) # t685: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t682\n",
+ " t686 = torch.reshape(t683, (1, 32, 512, 128)) # t686: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t686 = ltorch.reshape(t683, (1, 32, 512, 128)) # t686: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t686 = prims.reshape(t683, (1, 32, 512, 128)) # t686: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t683\n",
+ " t687 = torch_slice_prim_impl(t684, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t687: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t702 = torch_slice_prim_impl(t685, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t702: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t717 = torch_slice_prim_impl(t684, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t717: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t684\n",
+ " t719 = torch_slice_prim_impl(t685, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t719: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t685\n",
+ " t688 = torch_slice_prim_impl(t687, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t688: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t689 = torch_slice_prim_impl(t687, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t689: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t703 = torch_slice_prim_impl(t702, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t703: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t704 = torch_slice_prim_impl(t702, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t704: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " [t692, t707] = nvFusion26(t687, t689, t702, t704)\n",
+ " # t690 = prims.convert_element_type(t689, dtypes.float32) # t690: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t691 = prims.neg(t690) # t691: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t692 = prims.convert_element_type(t691, dtypes.bfloat16) # t692: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " # t705 = prims.convert_element_type(t704, dtypes.float32) # t705: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t706 = prims.neg(t705) # t706: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t707 = prims.convert_element_type(t706, dtypes.bfloat16) # t707: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " del t689, t704\n",
+ " t708 = torch.cat((t707, t703), -1) # t708: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t708 = ltorch.cat((t707, t703), -1) # t708: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t708 = prims.cat((t707, t703), -1) # t708: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t707, t703\n",
+ " t693 = torch.cat((t692, t688), -1) # t693: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t693 = ltorch.cat((t692, t688), -1) # t693: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t693 = prims.cat((t692, t688), -1) # t693: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t692, t688\n",
+ " [t701, t716] = nvFusion27(t154, t157, t687, t693, t702, t708)\n",
+ " # t695 = prims.convert_element_type(t687, dtypes.float32) # t695: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t710 = prims.convert_element_type(t702, dtypes.float32) # t710: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t711 = prims.mul(t710, t154) # t711: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t713 = prims.convert_element_type(t708, dtypes.float32) # t713: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t714 = prims.mul(t713, t157) # t714: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t715 = prims.add(t711, t714) # t715: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t716 = prims.convert_element_type(t715, dtypes.bfloat16) # t716: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t696 = prims.mul(t695, t154) # t696: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t698 = prims.convert_element_type(t693, dtypes.float32) # t698: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t699 = prims.mul(t698, t157) # t699: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t700 = prims.add(t696, t699) # t700: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t701 = prims.convert_element_type(t700, dtypes.bfloat16) # t701: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t687, t693, t702, t708\n",
+ " t720 = torch.cat((t716, t719), -1) # t720: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t720 = ltorch.cat((t716, t719), -1) # t720: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t720 = prims.cat((t716, t719), -1) # t720: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t716, t719\n",
+ " t718 = torch.cat((t701, t717), -1) # t718: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t718 = ltorch.cat((t701, t717), -1) # t718: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t718 = prims.cat((t701, t717), -1) # t718: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t701, t717\n",
+ " (t721, t722, t723, t724, _, _, t725, t726, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t718, t720, t686, 0.0, True, scale=0.08838834764831843)\n",
+ " t728 = torch.permute(t721, (0, 2, 1, 3)) # t728: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t728 = ltorch.permute(t721, (0, 2, 1, 3)) # t728: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t728 = prims.transpose(t721, (0, 2, 1, 3)) # t728: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " t729 = torch.reshape(t728, (1, 512, 4096)) # t729: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t729 = ltorch.reshape(t728, (1, 512, 4096)) # t729: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t729 = prims.reshape(t728, (1, 512, 4096)) # t729: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t728\n",
+ " t730 = torch.nn.functional.linear(t729, t95, None) # t730: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t730 = ltorch.linear(t729, t95, None) # t730: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t730 = prims.linear(t729, t95, None) # t730: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t734, t741, t749] = nvFusion28(t662, t730, t745)\n",
+ " # t732 = prims.convert_element_type(t662, dtypes.float32) # t732: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t731 = prims.convert_element_type(t730, dtypes.float32) # t731: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t733 = prims.add(t731, t732) # t733: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t734 = prims.convert_element_type(t733, dtypes.bfloat16) # t734: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t736 = prims.mul(t733, t733) # t736: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t737 = prims.sum(t736, (2,)) # t737: \"cuda:0 f32[1, 512]\"\n",
+ " # t738 = prims.broadcast_in_dim(t737, [1, 512, 1], [0, 1]) # t738: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t739 = prims.div(t738, 4096.0) # t739: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t740 = prims.add(t739, 1e-05) # t740: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t741 = prims.rsqrt(t740) # t741: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t742 = prims.broadcast_in_dim(t741, (1, 512, 4096), (0, 1, 2)) # t742: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t743 = prims.mul(t733, t742) # t743: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t747 = prims.convert_element_type(t745, dtypes.float32) # t747: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t748 = prims.mul(t743, t747) # t748: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t749 = prims.convert_element_type(t748, dtypes.bfloat16) # t749: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t750 = torch.nn.functional.linear(t749, t24, None) # t750: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t750 = ltorch.linear(t749, t24, None) # t750: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t750 = prims.linear(t749, t24, None) # t750: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t751 = torch.nn.functional.linear(t749, t40, None) # t751: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t751 = ltorch.linear(t749, t40, None) # t751: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t751 = prims.linear(t749, t40, None) # t751: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " [t765] = nvFusion29(t750, t751)\n",
+ " # t752 = prims.convert_element_type(t750, dtypes.float32) # t752: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t753 = prims.neg(t752) # t753: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t754 = prims.exp(t753) # t754: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t755 = prims.add(1.0, t754) # t755: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t756 = prims.reciprocal(t755) # t756: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t760 = prims.mul(t752, t756) # t760: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t763 = prims.convert_element_type(t751, dtypes.float32) # t763: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t764 = prims.mul(t760, t763) # t764: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t765 = prims.convert_element_type(t764, dtypes.bfloat16) # t765: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t766 = torch.nn.functional.linear(t765, t96, None) # t766: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t766 = ltorch.linear(t765, t96, None) # t766: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t766 = prims.linear(t765, t96, None) # t766: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t770, t777, t785] = nvFusion30(t734, t766, t781)\n",
+ " # t768 = prims.convert_element_type(t734, dtypes.float32) # t768: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t767 = prims.convert_element_type(t766, dtypes.float32) # t767: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t769 = prims.add(t767, t768) # t769: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t770 = prims.convert_element_type(t769, dtypes.bfloat16) # t770: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t772 = prims.mul(t769, t769) # t772: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t773 = prims.sum(t772, (2,)) # t773: \"cuda:0 f32[1, 512]\"\n",
+ " # t774 = prims.broadcast_in_dim(t773, [1, 512, 1], [0, 1]) # t774: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t775 = prims.div(t774, 4096.0) # t775: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t776 = prims.add(t775, 1e-05) # t776: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t777 = prims.rsqrt(t776) # t777: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t778 = prims.broadcast_in_dim(t777, (1, 512, 4096), (0, 1, 2)) # t778: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t779 = prims.mul(t769, t778) # t779: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t783 = prims.convert_element_type(t781, dtypes.float32) # t783: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t784 = prims.mul(t779, t783) # t784: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t785 = prims.convert_element_type(t784, dtypes.bfloat16) # t785: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t786 = torch.nn.functional.linear(t785, t9, None) # t786: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t786 = ltorch.linear(t785, t9, None) # t786: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t786 = prims.linear(t785, t9, None) # t786: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " t787 = torch.reshape(t786, (1, 512, 32, 3, 128)) # t787: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t787 = ltorch.reshape(t786, (1, 512, 32, 3, 128)) # t787: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t787 = prims.reshape(t786, (1, 512, 32, 3, 128)) # t787: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " del t786\n",
+ " t788 = torch.permute(t787, (0, 2, 3, 1, 4)) # t788: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t788 = ltorch.permute(t787, (0, 2, 3, 1, 4)) # t788: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t788 = prims.transpose(t787, (0, 2, 3, 1, 4)) # t788: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " del t787\n",
+ " (t789, t790, t791) = torch.split(t788, (1, 1, 1), 2)\n",
+ " # (t789, t790, t791) = ltorch.split(t788, (1, 1, 1), 2)\n",
+ " # t789 = prims.slice_prim(t788, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t789: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t790 = prims.slice_prim(t788, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t790: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t791 = prims.slice_prim(t788, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t791: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " del t788\n",
+ " t792 = torch.reshape(t789, (1, 32, 512, 128)) # t792: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t792 = ltorch.reshape(t789, (1, 32, 512, 128)) # t792: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t792 = prims.reshape(t789, (1, 32, 512, 128)) # t792: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t789\n",
+ " t793 = torch.reshape(t790, (1, 32, 512, 128)) # t793: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t793 = ltorch.reshape(t790, (1, 32, 512, 128)) # t793: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t793 = prims.reshape(t790, (1, 32, 512, 128)) # t793: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t790\n",
+ " t794 = torch.reshape(t791, (1, 32, 512, 128)) # t794: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t794 = ltorch.reshape(t791, (1, 32, 512, 128)) # t794: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t794 = prims.reshape(t791, (1, 32, 512, 128)) # t794: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t791\n",
+ " t795 = torch_slice_prim_impl(t792, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t795: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t810 = torch_slice_prim_impl(t793, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t810: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t825 = torch_slice_prim_impl(t792, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t825: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t792\n",
+ " t827 = torch_slice_prim_impl(t793, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t827: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t793\n",
+ " t796 = torch_slice_prim_impl(t795, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t796: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t797 = torch_slice_prim_impl(t795, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t797: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t811 = torch_slice_prim_impl(t810, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t811: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t812 = torch_slice_prim_impl(t810, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t812: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " [t800, t815] = nvFusion31(t795, t797, t810, t812)\n",
+ " # t798 = prims.convert_element_type(t797, dtypes.float32) # t798: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t799 = prims.neg(t798) # t799: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t800 = prims.convert_element_type(t799, dtypes.bfloat16) # t800: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " # t813 = prims.convert_element_type(t812, dtypes.float32) # t813: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t814 = prims.neg(t813) # t814: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t815 = prims.convert_element_type(t814, dtypes.bfloat16) # t815: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " del t797, t812\n",
+ " t816 = torch.cat((t815, t811), -1) # t816: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t816 = ltorch.cat((t815, t811), -1) # t816: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t816 = prims.cat((t815, t811), -1) # t816: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t815, t811\n",
+ " t801 = torch.cat((t800, t796), -1) # t801: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t801 = ltorch.cat((t800, t796), -1) # t801: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t801 = prims.cat((t800, t796), -1) # t801: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t800, t796\n",
+ " [t809, t824] = nvFusion32(t154, t157, t795, t801, t810, t816)\n",
+ " # t803 = prims.convert_element_type(t795, dtypes.float32) # t803: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t818 = prims.convert_element_type(t810, dtypes.float32) # t818: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t819 = prims.mul(t818, t154) # t819: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t821 = prims.convert_element_type(t816, dtypes.float32) # t821: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t822 = prims.mul(t821, t157) # t822: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t823 = prims.add(t819, t822) # t823: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t824 = prims.convert_element_type(t823, dtypes.bfloat16) # t824: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t804 = prims.mul(t803, t154) # t804: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t806 = prims.convert_element_type(t801, dtypes.float32) # t806: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t807 = prims.mul(t806, t157) # t807: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t808 = prims.add(t804, t807) # t808: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t809 = prims.convert_element_type(t808, dtypes.bfloat16) # t809: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t795, t801, t810, t816\n",
+ " t828 = torch.cat((t824, t827), -1) # t828: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t828 = ltorch.cat((t824, t827), -1) # t828: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t828 = prims.cat((t824, t827), -1) # t828: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t824, t827\n",
+ " t826 = torch.cat((t809, t825), -1) # t826: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t826 = ltorch.cat((t809, t825), -1) # t826: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t826 = prims.cat((t809, t825), -1) # t826: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t809, t825\n",
+ " (t829, t830, t831, t832, _, _, t833, t834, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t826, t828, t794, 0.0, True, scale=0.08838834764831843)\n",
+ " t836 = torch.permute(t829, (0, 2, 1, 3)) # t836: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t836 = ltorch.permute(t829, (0, 2, 1, 3)) # t836: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t836 = prims.transpose(t829, (0, 2, 1, 3)) # t836: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " t837 = torch.reshape(t836, (1, 512, 4096)) # t837: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t837 = ltorch.reshape(t836, (1, 512, 4096)) # t837: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t837 = prims.reshape(t836, (1, 512, 4096)) # t837: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t836\n",
+ " t838 = torch.nn.functional.linear(t837, t97, None) # t838: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t838 = ltorch.linear(t837, t97, None) # t838: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t838 = prims.linear(t837, t97, None) # t838: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t842, t849, t857] = nvFusion33(t770, t838, t853)\n",
+ " # t840 = prims.convert_element_type(t770, dtypes.float32) # t840: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t839 = prims.convert_element_type(t838, dtypes.float32) # t839: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t841 = prims.add(t839, t840) # t841: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t842 = prims.convert_element_type(t841, dtypes.bfloat16) # t842: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t844 = prims.mul(t841, t841) # t844: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t845 = prims.sum(t844, (2,)) # t845: \"cuda:0 f32[1, 512]\"\n",
+ " # t846 = prims.broadcast_in_dim(t845, [1, 512, 1], [0, 1]) # t846: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t847 = prims.div(t846, 4096.0) # t847: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t848 = prims.add(t847, 1e-05) # t848: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t849 = prims.rsqrt(t848) # t849: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t850 = prims.broadcast_in_dim(t849, (1, 512, 4096), (0, 1, 2)) # t850: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t851 = prims.mul(t841, t850) # t851: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t855 = prims.convert_element_type(t853, dtypes.float32) # t855: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t856 = prims.mul(t851, t855) # t856: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t857 = prims.convert_element_type(t856, dtypes.bfloat16) # t857: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t858 = torch.nn.functional.linear(t857, t25, None) # t858: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t858 = ltorch.linear(t857, t25, None) # t858: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t858 = prims.linear(t857, t25, None) # t858: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t859 = torch.nn.functional.linear(t857, t41, None) # t859: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t859 = ltorch.linear(t857, t41, None) # t859: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t859 = prims.linear(t857, t41, None) # t859: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " [t873] = nvFusion34(t858, t859)\n",
+ " # t860 = prims.convert_element_type(t858, dtypes.float32) # t860: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t861 = prims.neg(t860) # t861: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t862 = prims.exp(t861) # t862: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t863 = prims.add(1.0, t862) # t863: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t864 = prims.reciprocal(t863) # t864: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t868 = prims.mul(t860, t864) # t868: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t871 = prims.convert_element_type(t859, dtypes.float32) # t871: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t872 = prims.mul(t868, t871) # t872: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t873 = prims.convert_element_type(t872, dtypes.bfloat16) # t873: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t874 = torch.nn.functional.linear(t873, t98, None) # t874: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t874 = ltorch.linear(t873, t98, None) # t874: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t874 = prims.linear(t873, t98, None) # t874: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t878, t885, t893] = nvFusion35(t842, t874, t889)\n",
+ " # t876 = prims.convert_element_type(t842, dtypes.float32) # t876: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t875 = prims.convert_element_type(t874, dtypes.float32) # t875: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t877 = prims.add(t875, t876) # t877: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t878 = prims.convert_element_type(t877, dtypes.bfloat16) # t878: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t880 = prims.mul(t877, t877) # t880: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t881 = prims.sum(t880, (2,)) # t881: \"cuda:0 f32[1, 512]\"\n",
+ " # t882 = prims.broadcast_in_dim(t881, [1, 512, 1], [0, 1]) # t882: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t883 = prims.div(t882, 4096.0) # t883: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t884 = prims.add(t883, 1e-05) # t884: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t885 = prims.rsqrt(t884) # t885: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t886 = prims.broadcast_in_dim(t885, (1, 512, 4096), (0, 1, 2)) # t886: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t887 = prims.mul(t877, t886) # t887: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t891 = prims.convert_element_type(t889, dtypes.float32) # t891: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t892 = prims.mul(t887, t891) # t892: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t893 = prims.convert_element_type(t892, dtypes.bfloat16) # t893: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t894 = torch.nn.functional.linear(t893, t10, None) # t894: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t894 = ltorch.linear(t893, t10, None) # t894: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t894 = prims.linear(t893, t10, None) # t894: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " t895 = torch.reshape(t894, (1, 512, 32, 3, 128)) # t895: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t895 = ltorch.reshape(t894, (1, 512, 32, 3, 128)) # t895: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t895 = prims.reshape(t894, (1, 512, 32, 3, 128)) # t895: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " del t894\n",
+ " t896 = torch.permute(t895, (0, 2, 3, 1, 4)) # t896: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t896 = ltorch.permute(t895, (0, 2, 3, 1, 4)) # t896: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t896 = prims.transpose(t895, (0, 2, 3, 1, 4)) # t896: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " del t895\n",
+ " (t897, t898, t899) = torch.split(t896, (1, 1, 1), 2)\n",
+ " # (t897, t898, t899) = ltorch.split(t896, (1, 1, 1), 2)\n",
+ " # t897 = prims.slice_prim(t896, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t897: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t898 = prims.slice_prim(t896, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t898: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t899 = prims.slice_prim(t896, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t899: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " del t896\n",
+ " t900 = torch.reshape(t897, (1, 32, 512, 128)) # t900: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t900 = ltorch.reshape(t897, (1, 32, 512, 128)) # t900: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t900 = prims.reshape(t897, (1, 32, 512, 128)) # t900: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t897\n",
+ " t901 = torch.reshape(t898, (1, 32, 512, 128)) # t901: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t901 = ltorch.reshape(t898, (1, 32, 512, 128)) # t901: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t901 = prims.reshape(t898, (1, 32, 512, 128)) # t901: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t898\n",
+ " t902 = torch.reshape(t899, (1, 32, 512, 128)) # t902: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t902 = ltorch.reshape(t899, (1, 32, 512, 128)) # t902: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t902 = prims.reshape(t899, (1, 32, 512, 128)) # t902: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t899\n",
+ " t935 = torch_slice_prim_impl(t901, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t935: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " t903 = torch_slice_prim_impl(t900, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t903: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t918 = torch_slice_prim_impl(t901, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t918: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t901\n",
+ " t933 = torch_slice_prim_impl(t900, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t933: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t900\n",
+ " t904 = torch_slice_prim_impl(t903, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t904: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t905 = torch_slice_prim_impl(t903, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t905: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t919 = torch_slice_prim_impl(t918, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t919: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t920 = torch_slice_prim_impl(t918, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t920: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " [t908, t923] = nvFusion36(t903, t905, t918, t920)\n",
+ " # t906 = prims.convert_element_type(t905, dtypes.float32) # t906: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t907 = prims.neg(t906) # t907: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t908 = prims.convert_element_type(t907, dtypes.bfloat16) # t908: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " # t921 = prims.convert_element_type(t920, dtypes.float32) # t921: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t922 = prims.neg(t921) # t922: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t923 = prims.convert_element_type(t922, dtypes.bfloat16) # t923: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " del t905, t920\n",
+ " t924 = torch.cat((t923, t919), -1) # t924: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t924 = ltorch.cat((t923, t919), -1) # t924: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t924 = prims.cat((t923, t919), -1) # t924: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t923, t919\n",
+ " t909 = torch.cat((t908, t904), -1) # t909: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t909 = ltorch.cat((t908, t904), -1) # t909: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t909 = prims.cat((t908, t904), -1) # t909: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t908, t904\n",
+ " [t917, t932] = nvFusion37(t154, t157, t903, t909, t918, t924)\n",
+ " # t911 = prims.convert_element_type(t903, dtypes.float32) # t911: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t926 = prims.convert_element_type(t918, dtypes.float32) # t926: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t927 = prims.mul(t926, t154) # t927: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t929 = prims.convert_element_type(t924, dtypes.float32) # t929: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t930 = prims.mul(t929, t157) # t930: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t931 = prims.add(t927, t930) # t931: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t932 = prims.convert_element_type(t931, dtypes.bfloat16) # t932: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t912 = prims.mul(t911, t154) # t912: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t914 = prims.convert_element_type(t909, dtypes.float32) # t914: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t915 = prims.mul(t914, t157) # t915: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t916 = prims.add(t912, t915) # t916: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t917 = prims.convert_element_type(t916, dtypes.bfloat16) # t917: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t903, t909, t918, t924\n",
+ " t936 = torch.cat((t932, t935), -1) # t936: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t936 = ltorch.cat((t932, t935), -1) # t936: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t936 = prims.cat((t932, t935), -1) # t936: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t932, t935\n",
+ " t934 = torch.cat((t917, t933), -1) # t934: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t934 = ltorch.cat((t917, t933), -1) # t934: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t934 = prims.cat((t917, t933), -1) # t934: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t917, t933\n",
+ " (t937, t938, t939, t940, _, _, t941, t942, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t934, t936, t902, 0.0, True, scale=0.08838834764831843)\n",
+ " t944 = torch.permute(t937, (0, 2, 1, 3)) # t944: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t944 = ltorch.permute(t937, (0, 2, 1, 3)) # t944: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t944 = prims.transpose(t937, (0, 2, 1, 3)) # t944: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " t945 = torch.reshape(t944, (1, 512, 4096)) # t945: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t945 = ltorch.reshape(t944, (1, 512, 4096)) # t945: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t945 = prims.reshape(t944, (1, 512, 4096)) # t945: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t944\n",
+ " t946 = torch.nn.functional.linear(t945, t99, None) # t946: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t946 = ltorch.linear(t945, t99, None) # t946: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t946 = prims.linear(t945, t99, None) # t946: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t950, t957, t965] = nvFusion38(t878, t946, t961)\n",
+ " # t948 = prims.convert_element_type(t878, dtypes.float32) # t948: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t947 = prims.convert_element_type(t946, dtypes.float32) # t947: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t949 = prims.add(t947, t948) # t949: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t950 = prims.convert_element_type(t949, dtypes.bfloat16) # t950: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t952 = prims.mul(t949, t949) # t952: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t953 = prims.sum(t952, (2,)) # t953: \"cuda:0 f32[1, 512]\"\n",
+ " # t954 = prims.broadcast_in_dim(t953, [1, 512, 1], [0, 1]) # t954: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t955 = prims.div(t954, 4096.0) # t955: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t956 = prims.add(t955, 1e-05) # t956: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t957 = prims.rsqrt(t956) # t957: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t958 = prims.broadcast_in_dim(t957, (1, 512, 4096), (0, 1, 2)) # t958: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t959 = prims.mul(t949, t958) # t959: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t963 = prims.convert_element_type(t961, dtypes.float32) # t963: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t964 = prims.mul(t959, t963) # t964: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t965 = prims.convert_element_type(t964, dtypes.bfloat16) # t965: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t967 = torch.nn.functional.linear(t965, t42, None) # t967: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t967 = ltorch.linear(t965, t42, None) # t967: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t967 = prims.linear(t965, t42, None) # t967: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t966 = torch.nn.functional.linear(t965, t26, None) # t966: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t966 = ltorch.linear(t965, t26, None) # t966: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t966 = prims.linear(t965, t26, None) # t966: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " [t981] = nvFusion39(t966, t967)\n",
+ " # t968 = prims.convert_element_type(t966, dtypes.float32) # t968: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t969 = prims.neg(t968) # t969: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t970 = prims.exp(t969) # t970: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t971 = prims.add(1.0, t970) # t971: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t972 = prims.reciprocal(t971) # t972: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t976 = prims.mul(t968, t972) # t976: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t979 = prims.convert_element_type(t967, dtypes.float32) # t979: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t980 = prims.mul(t976, t979) # t980: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t981 = prims.convert_element_type(t980, dtypes.bfloat16) # t981: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t982 = torch.nn.functional.linear(t981, t100, None) # t982: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t982 = ltorch.linear(t981, t100, None) # t982: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t982 = prims.linear(t981, t100, None) # t982: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t1001, t986, t993] = nvFusion40(t950, t982, t997)\n",
+ " # t984 = prims.convert_element_type(t950, dtypes.float32) # t984: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t983 = prims.convert_element_type(t982, dtypes.float32) # t983: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t985 = prims.add(t983, t984) # t985: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t986 = prims.convert_element_type(t985, dtypes.bfloat16) # t986: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t988 = prims.mul(t985, t985) # t988: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t989 = prims.sum(t988, (2,)) # t989: \"cuda:0 f32[1, 512]\"\n",
+ " # t990 = prims.broadcast_in_dim(t989, [1, 512, 1], [0, 1]) # t990: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t991 = prims.div(t990, 4096.0) # t991: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t992 = prims.add(t991, 1e-05) # t992: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t993 = prims.rsqrt(t992) # t993: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t994 = prims.broadcast_in_dim(t993, (1, 512, 4096), (0, 1, 2)) # t994: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t995 = prims.mul(t985, t994) # t995: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t999 = prims.convert_element_type(t997, dtypes.float32) # t999: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1000 = prims.mul(t995, t999) # t1000: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1001 = prims.convert_element_type(t1000, dtypes.bfloat16) # t1001: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t1002 = torch.nn.functional.linear(t1001, t11, None) # t1002: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t1002 = ltorch.linear(t1001, t11, None) # t1002: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t1002 = prims.linear(t1001, t11, None) # t1002: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " t1003 = torch.reshape(t1002, (1, 512, 32, 3, 128)) # t1003: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t1003 = ltorch.reshape(t1002, (1, 512, 32, 3, 128)) # t1003: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t1003 = prims.reshape(t1002, (1, 512, 32, 3, 128)) # t1003: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " del t1002\n",
+ " t1004 = torch.permute(t1003, (0, 2, 3, 1, 4)) # t1004: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t1004 = ltorch.permute(t1003, (0, 2, 3, 1, 4)) # t1004: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t1004 = prims.transpose(t1003, (0, 2, 3, 1, 4)) # t1004: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " del t1003\n",
+ " (t1005, t1006, t1007) = torch.split(t1004, (1, 1, 1), 2)\n",
+ " # (t1005, t1006, t1007) = ltorch.split(t1004, (1, 1, 1), 2)\n",
+ " # t1005 = prims.slice_prim(t1004, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1005: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t1006 = prims.slice_prim(t1004, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1006: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t1007 = prims.slice_prim(t1004, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1007: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " del t1004\n",
+ " t1008 = torch.reshape(t1005, (1, 32, 512, 128)) # t1008: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1008 = ltorch.reshape(t1005, (1, 32, 512, 128)) # t1008: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1008 = prims.reshape(t1005, (1, 32, 512, 128)) # t1008: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1005\n",
+ " t1009 = torch.reshape(t1006, (1, 32, 512, 128)) # t1009: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1009 = ltorch.reshape(t1006, (1, 32, 512, 128)) # t1009: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1009 = prims.reshape(t1006, (1, 32, 512, 128)) # t1009: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1006\n",
+ " t1010 = torch.reshape(t1007, (1, 32, 512, 128)) # t1010: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1010 = ltorch.reshape(t1007, (1, 32, 512, 128)) # t1010: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1010 = prims.reshape(t1007, (1, 32, 512, 128)) # t1010: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1007\n",
+ " t1026 = torch_slice_prim_impl(t1009, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1026: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t1041 = torch_slice_prim_impl(t1008, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1041: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " t1043 = torch_slice_prim_impl(t1009, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1043: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t1009\n",
+ " t1011 = torch_slice_prim_impl(t1008, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1011: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1008\n",
+ " t1027 = torch_slice_prim_impl(t1026, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1027: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1028 = torch_slice_prim_impl(t1026, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1028: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1013 = torch_slice_prim_impl(t1011, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1013: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1012 = torch_slice_prim_impl(t1011, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1012: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " [t1016, t1031] = nvFusion41(t1011, t1013, t1026, t1028)\n",
+ " # t1014 = prims.convert_element_type(t1013, dtypes.float32) # t1014: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1015 = prims.neg(t1014) # t1015: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1016 = prims.convert_element_type(t1015, dtypes.bfloat16) # t1016: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " # t1029 = prims.convert_element_type(t1028, dtypes.float32) # t1029: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1030 = prims.neg(t1029) # t1030: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1031 = prims.convert_element_type(t1030, dtypes.bfloat16) # t1031: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " del t1013, t1028\n",
+ " t1032 = torch.cat((t1031, t1027), -1) # t1032: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1032 = ltorch.cat((t1031, t1027), -1) # t1032: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1032 = prims.cat((t1031, t1027), -1) # t1032: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1031, t1027\n",
+ " t1017 = torch.cat((t1016, t1012), -1) # t1017: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1017 = ltorch.cat((t1016, t1012), -1) # t1017: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1017 = prims.cat((t1016, t1012), -1) # t1017: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1016, t1012\n",
+ " [t1025, t1040] = nvFusion42(t1011, t1017, t1026, t1032, t154, t157)\n",
+ " # t1019 = prims.convert_element_type(t1011, dtypes.float32) # t1019: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1034 = prims.convert_element_type(t1026, dtypes.float32) # t1034: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1020 = prims.mul(t1019, t154) # t1020: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1022 = prims.convert_element_type(t1017, dtypes.float32) # t1022: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1023 = prims.mul(t1022, t157) # t1023: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1024 = prims.add(t1020, t1023) # t1024: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1025 = prims.convert_element_type(t1024, dtypes.bfloat16) # t1025: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1035 = prims.mul(t1034, t154) # t1035: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1037 = prims.convert_element_type(t1032, dtypes.float32) # t1037: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1038 = prims.mul(t1037, t157) # t1038: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1039 = prims.add(t1035, t1038) # t1039: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1040 = prims.convert_element_type(t1039, dtypes.bfloat16) # t1040: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1011, t1017, t1026, t1032\n",
+ " t1042 = torch.cat((t1025, t1041), -1) # t1042: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1042 = ltorch.cat((t1025, t1041), -1) # t1042: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1042 = prims.cat((t1025, t1041), -1) # t1042: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1025, t1041\n",
+ " t1044 = torch.cat((t1040, t1043), -1) # t1044: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1044 = ltorch.cat((t1040, t1043), -1) # t1044: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1044 = prims.cat((t1040, t1043), -1) # t1044: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1040, t1043\n",
+ " (t1045, t1046, t1047, t1048, _, _, t1049, t1050, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1042, t1044, t1010, 0.0, True, scale=0.08838834764831843)\n",
+ " t1052 = torch.permute(t1045, (0, 2, 1, 3)) # t1052: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t1052 = ltorch.permute(t1045, (0, 2, 1, 3)) # t1052: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t1052 = prims.transpose(t1045, (0, 2, 1, 3)) # t1052: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " t1053 = torch.reshape(t1052, (1, 512, 4096)) # t1053: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1053 = ltorch.reshape(t1052, (1, 512, 4096)) # t1053: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1053 = prims.reshape(t1052, (1, 512, 4096)) # t1053: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t1052\n",
+ " t1054 = torch.nn.functional.linear(t1053, t101, None) # t1054: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1054 = ltorch.linear(t1053, t101, None) # t1054: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1054 = prims.linear(t1053, t101, None) # t1054: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t1058, t1065, t1073] = nvFusion43(t1054, t1069, t986)\n",
+ " # t1056 = prims.convert_element_type(t986, dtypes.float32) # t1056: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1055 = prims.convert_element_type(t1054, dtypes.float32) # t1055: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1057 = prims.add(t1055, t1056) # t1057: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1058 = prims.convert_element_type(t1057, dtypes.bfloat16) # t1058: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1060 = prims.mul(t1057, t1057) # t1060: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1061 = prims.sum(t1060, (2,)) # t1061: \"cuda:0 f32[1, 512]\"\n",
+ " # t1062 = prims.broadcast_in_dim(t1061, [1, 512, 1], [0, 1]) # t1062: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1063 = prims.div(t1062, 4096.0) # t1063: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1064 = prims.add(t1063, 1e-05) # t1064: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1065 = prims.rsqrt(t1064) # t1065: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1066 = prims.broadcast_in_dim(t1065, (1, 512, 4096), (0, 1, 2)) # t1066: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1067 = prims.mul(t1057, t1066) # t1067: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1071 = prims.convert_element_type(t1069, dtypes.float32) # t1071: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1072 = prims.mul(t1067, t1071) # t1072: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1073 = prims.convert_element_type(t1072, dtypes.bfloat16) # t1073: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t1074 = torch.nn.functional.linear(t1073, t27, None) # t1074: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1074 = ltorch.linear(t1073, t27, None) # t1074: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1074 = prims.linear(t1073, t27, None) # t1074: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t1075 = torch.nn.functional.linear(t1073, t43, None) # t1075: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1075 = ltorch.linear(t1073, t43, None) # t1075: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1075 = prims.linear(t1073, t43, None) # t1075: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " [t1089] = nvFusion44(t1074, t1075)\n",
+ " # t1076 = prims.convert_element_type(t1074, dtypes.float32) # t1076: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1077 = prims.neg(t1076) # t1077: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1078 = prims.exp(t1077) # t1078: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1079 = prims.add(1.0, t1078) # t1079: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1080 = prims.reciprocal(t1079) # t1080: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1084 = prims.mul(t1076, t1080) # t1084: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1087 = prims.convert_element_type(t1075, dtypes.float32) # t1087: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1088 = prims.mul(t1084, t1087) # t1088: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1089 = prims.convert_element_type(t1088, dtypes.bfloat16) # t1089: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t1090 = torch.nn.functional.linear(t1089, t102, None) # t1090: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1090 = ltorch.linear(t1089, t102, None) # t1090: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1090 = prims.linear(t1089, t102, None) # t1090: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t1094, t1101, t1109] = nvFusion45(t1058, t1090, t1105)\n",
+ " # t1092 = prims.convert_element_type(t1058, dtypes.float32) # t1092: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1091 = prims.convert_element_type(t1090, dtypes.float32) # t1091: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1093 = prims.add(t1091, t1092) # t1093: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1094 = prims.convert_element_type(t1093, dtypes.bfloat16) # t1094: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1096 = prims.mul(t1093, t1093) # t1096: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1097 = prims.sum(t1096, (2,)) # t1097: \"cuda:0 f32[1, 512]\"\n",
+ " # t1098 = prims.broadcast_in_dim(t1097, [1, 512, 1], [0, 1]) # t1098: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1099 = prims.div(t1098, 4096.0) # t1099: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1100 = prims.add(t1099, 1e-05) # t1100: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1101 = prims.rsqrt(t1100) # t1101: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1102 = prims.broadcast_in_dim(t1101, (1, 512, 4096), (0, 1, 2)) # t1102: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1103 = prims.mul(t1093, t1102) # t1103: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1107 = prims.convert_element_type(t1105, dtypes.float32) # t1107: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1108 = prims.mul(t1103, t1107) # t1108: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1109 = prims.convert_element_type(t1108, dtypes.bfloat16) # t1109: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t1110 = torch.nn.functional.linear(t1109, t12, None) # t1110: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t1110 = ltorch.linear(t1109, t12, None) # t1110: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t1110 = prims.linear(t1109, t12, None) # t1110: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " t1111 = torch.reshape(t1110, (1, 512, 32, 3, 128)) # t1111: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t1111 = ltorch.reshape(t1110, (1, 512, 32, 3, 128)) # t1111: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t1111 = prims.reshape(t1110, (1, 512, 32, 3, 128)) # t1111: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " del t1110\n",
+ " t1112 = torch.permute(t1111, (0, 2, 3, 1, 4)) # t1112: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t1112 = ltorch.permute(t1111, (0, 2, 3, 1, 4)) # t1112: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t1112 = prims.transpose(t1111, (0, 2, 3, 1, 4)) # t1112: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " del t1111\n",
+ " (t1113, t1114, t1115) = torch.split(t1112, (1, 1, 1), 2)\n",
+ " # (t1113, t1114, t1115) = ltorch.split(t1112, (1, 1, 1), 2)\n",
+ " # t1113 = prims.slice_prim(t1112, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1113: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t1114 = prims.slice_prim(t1112, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1114: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t1115 = prims.slice_prim(t1112, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1115: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " del t1112\n",
+ " t1116 = torch.reshape(t1113, (1, 32, 512, 128)) # t1116: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1116 = ltorch.reshape(t1113, (1, 32, 512, 128)) # t1116: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1116 = prims.reshape(t1113, (1, 32, 512, 128)) # t1116: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1113\n",
+ " t1117 = torch.reshape(t1114, (1, 32, 512, 128)) # t1117: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1117 = ltorch.reshape(t1114, (1, 32, 512, 128)) # t1117: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1117 = prims.reshape(t1114, (1, 32, 512, 128)) # t1117: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1114\n",
+ " t1118 = torch.reshape(t1115, (1, 32, 512, 128)) # t1118: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1118 = ltorch.reshape(t1115, (1, 32, 512, 128)) # t1118: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1118 = prims.reshape(t1115, (1, 32, 512, 128)) # t1118: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1115\n",
+ " t1119 = torch_slice_prim_impl(t1116, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1119: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t1134 = torch_slice_prim_impl(t1117, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1134: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t1149 = torch_slice_prim_impl(t1116, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1149: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t1116\n",
+ " t1151 = torch_slice_prim_impl(t1117, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1151: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t1117\n",
+ " t1120 = torch_slice_prim_impl(t1119, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1120: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1121 = torch_slice_prim_impl(t1119, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1121: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1136 = torch_slice_prim_impl(t1134, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1136: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1135 = torch_slice_prim_impl(t1134, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1135: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " [t1124, t1139] = nvFusion46(t1119, t1121, t1134, t1136)\n",
+ " # t1122 = prims.convert_element_type(t1121, dtypes.float32) # t1122: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1123 = prims.neg(t1122) # t1123: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1124 = prims.convert_element_type(t1123, dtypes.bfloat16) # t1124: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " # t1137 = prims.convert_element_type(t1136, dtypes.float32) # t1137: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1138 = prims.neg(t1137) # t1138: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1139 = prims.convert_element_type(t1138, dtypes.bfloat16) # t1139: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " del t1121, t1136\n",
+ " t1125 = torch.cat((t1124, t1120), -1) # t1125: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1125 = ltorch.cat((t1124, t1120), -1) # t1125: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1125 = prims.cat((t1124, t1120), -1) # t1125: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1124, t1120\n",
+ " t1140 = torch.cat((t1139, t1135), -1) # t1140: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1140 = ltorch.cat((t1139, t1135), -1) # t1140: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1140 = prims.cat((t1139, t1135), -1) # t1140: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1139, t1135\n",
+ " [t1133, t1148] = nvFusion47(t1119, t1125, t1134, t1140, t154, t157)\n",
+ " # t1127 = prims.convert_element_type(t1119, dtypes.float32) # t1127: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1142 = prims.convert_element_type(t1134, dtypes.float32) # t1142: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1128 = prims.mul(t1127, t154) # t1128: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1130 = prims.convert_element_type(t1125, dtypes.float32) # t1130: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1131 = prims.mul(t1130, t157) # t1131: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1132 = prims.add(t1128, t1131) # t1132: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1133 = prims.convert_element_type(t1132, dtypes.bfloat16) # t1133: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1143 = prims.mul(t1142, t154) # t1143: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1145 = prims.convert_element_type(t1140, dtypes.float32) # t1145: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1146 = prims.mul(t1145, t157) # t1146: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1147 = prims.add(t1143, t1146) # t1147: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1148 = prims.convert_element_type(t1147, dtypes.bfloat16) # t1148: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1119, t1125, t1134, t1140\n",
+ " t1152 = torch.cat((t1148, t1151), -1) # t1152: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1152 = ltorch.cat((t1148, t1151), -1) # t1152: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1152 = prims.cat((t1148, t1151), -1) # t1152: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1148, t1151\n",
+ " t1150 = torch.cat((t1133, t1149), -1) # t1150: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1150 = ltorch.cat((t1133, t1149), -1) # t1150: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1150 = prims.cat((t1133, t1149), -1) # t1150: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1133, t1149\n",
+ " (t1153, t1154, t1155, t1156, _, _, t1157, t1158, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1150, t1152, t1118, 0.0, True, scale=0.08838834764831843)\n",
+ " t1160 = torch.permute(t1153, (0, 2, 1, 3)) # t1160: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t1160 = ltorch.permute(t1153, (0, 2, 1, 3)) # t1160: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t1160 = prims.transpose(t1153, (0, 2, 1, 3)) # t1160: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " t1161 = torch.reshape(t1160, (1, 512, 4096)) # t1161: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1161 = ltorch.reshape(t1160, (1, 512, 4096)) # t1161: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1161 = prims.reshape(t1160, (1, 512, 4096)) # t1161: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t1160\n",
+ " t1162 = torch.nn.functional.linear(t1161, t103, None) # t1162: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1162 = ltorch.linear(t1161, t103, None) # t1162: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1162 = prims.linear(t1161, t103, None) # t1162: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t1166, t1173, t1181] = nvFusion48(t1094, t1162, t1177)\n",
+ " # t1164 = prims.convert_element_type(t1094, dtypes.float32) # t1164: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1163 = prims.convert_element_type(t1162, dtypes.float32) # t1163: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1165 = prims.add(t1163, t1164) # t1165: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1166 = prims.convert_element_type(t1165, dtypes.bfloat16) # t1166: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1168 = prims.mul(t1165, t1165) # t1168: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1169 = prims.sum(t1168, (2,)) # t1169: \"cuda:0 f32[1, 512]\"\n",
+ " # t1170 = prims.broadcast_in_dim(t1169, [1, 512, 1], [0, 1]) # t1170: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1171 = prims.div(t1170, 4096.0) # t1171: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1172 = prims.add(t1171, 1e-05) # t1172: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1173 = prims.rsqrt(t1172) # t1173: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1174 = prims.broadcast_in_dim(t1173, (1, 512, 4096), (0, 1, 2)) # t1174: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1175 = prims.mul(t1165, t1174) # t1175: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1179 = prims.convert_element_type(t1177, dtypes.float32) # t1179: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1180 = prims.mul(t1175, t1179) # t1180: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1181 = prims.convert_element_type(t1180, dtypes.bfloat16) # t1181: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t1182 = torch.nn.functional.linear(t1181, t28, None) # t1182: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1182 = ltorch.linear(t1181, t28, None) # t1182: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1182 = prims.linear(t1181, t28, None) # t1182: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t1183 = torch.nn.functional.linear(t1181, t44, None) # t1183: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1183 = ltorch.linear(t1181, t44, None) # t1183: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1183 = prims.linear(t1181, t44, None) # t1183: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " [t1197] = nvFusion49(t1182, t1183)\n",
+ " # t1184 = prims.convert_element_type(t1182, dtypes.float32) # t1184: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1185 = prims.neg(t1184) # t1185: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1186 = prims.exp(t1185) # t1186: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1187 = prims.add(1.0, t1186) # t1187: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1188 = prims.reciprocal(t1187) # t1188: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1192 = prims.mul(t1184, t1188) # t1192: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1195 = prims.convert_element_type(t1183, dtypes.float32) # t1195: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1196 = prims.mul(t1192, t1195) # t1196: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1197 = prims.convert_element_type(t1196, dtypes.bfloat16) # t1197: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t1198 = torch.nn.functional.linear(t1197, t104, None) # t1198: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1198 = ltorch.linear(t1197, t104, None) # t1198: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1198 = prims.linear(t1197, t104, None) # t1198: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t1202, t1209, t1217] = nvFusion50(t1166, t1198, t1213)\n",
+ " # t1200 = prims.convert_element_type(t1166, dtypes.float32) # t1200: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1199 = prims.convert_element_type(t1198, dtypes.float32) # t1199: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1201 = prims.add(t1199, t1200) # t1201: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1202 = prims.convert_element_type(t1201, dtypes.bfloat16) # t1202: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1204 = prims.mul(t1201, t1201) # t1204: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1205 = prims.sum(t1204, (2,)) # t1205: \"cuda:0 f32[1, 512]\"\n",
+ " # t1206 = prims.broadcast_in_dim(t1205, [1, 512, 1], [0, 1]) # t1206: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1207 = prims.div(t1206, 4096.0) # t1207: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1208 = prims.add(t1207, 1e-05) # t1208: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1209 = prims.rsqrt(t1208) # t1209: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1210 = prims.broadcast_in_dim(t1209, (1, 512, 4096), (0, 1, 2)) # t1210: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1211 = prims.mul(t1201, t1210) # t1211: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1215 = prims.convert_element_type(t1213, dtypes.float32) # t1215: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1216 = prims.mul(t1211, t1215) # t1216: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1217 = prims.convert_element_type(t1216, dtypes.bfloat16) # t1217: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t1218 = torch.nn.functional.linear(t1217, t13, None) # t1218: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t1218 = ltorch.linear(t1217, t13, None) # t1218: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t1218 = prims.linear(t1217, t13, None) # t1218: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " t1219 = torch.reshape(t1218, (1, 512, 32, 3, 128)) # t1219: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t1219 = ltorch.reshape(t1218, (1, 512, 32, 3, 128)) # t1219: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t1219 = prims.reshape(t1218, (1, 512, 32, 3, 128)) # t1219: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " del t1218\n",
+ " t1220 = torch.permute(t1219, (0, 2, 3, 1, 4)) # t1220: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t1220 = ltorch.permute(t1219, (0, 2, 3, 1, 4)) # t1220: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t1220 = prims.transpose(t1219, (0, 2, 3, 1, 4)) # t1220: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " del t1219\n",
+ " (t1221, t1222, t1223) = torch.split(t1220, (1, 1, 1), 2)\n",
+ " # (t1221, t1222, t1223) = ltorch.split(t1220, (1, 1, 1), 2)\n",
+ " # t1221 = prims.slice_prim(t1220, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1221: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t1222 = prims.slice_prim(t1220, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1222: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t1223 = prims.slice_prim(t1220, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1223: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " del t1220\n",
+ " t1224 = torch.reshape(t1221, (1, 32, 512, 128)) # t1224: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1224 = ltorch.reshape(t1221, (1, 32, 512, 128)) # t1224: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1224 = prims.reshape(t1221, (1, 32, 512, 128)) # t1224: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1221\n",
+ " t1225 = torch.reshape(t1222, (1, 32, 512, 128)) # t1225: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1225 = ltorch.reshape(t1222, (1, 32, 512, 128)) # t1225: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1225 = prims.reshape(t1222, (1, 32, 512, 128)) # t1225: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1222\n",
+ " t1226 = torch.reshape(t1223, (1, 32, 512, 128)) # t1226: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1226 = ltorch.reshape(t1223, (1, 32, 512, 128)) # t1226: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1226 = prims.reshape(t1223, (1, 32, 512, 128)) # t1226: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1223\n",
+ " t1227 = torch_slice_prim_impl(t1224, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1227: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t1242 = torch_slice_prim_impl(t1225, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1242: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t1257 = torch_slice_prim_impl(t1224, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1257: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t1224\n",
+ " t1259 = torch_slice_prim_impl(t1225, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1259: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t1225\n",
+ " t1228 = torch_slice_prim_impl(t1227, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1228: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1229 = torch_slice_prim_impl(t1227, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1229: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1243 = torch_slice_prim_impl(t1242, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1243: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1244 = torch_slice_prim_impl(t1242, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1244: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " [t1232, t1247] = nvFusion51(t1227, t1229, t1242, t1244)\n",
+ " # t1230 = prims.convert_element_type(t1229, dtypes.float32) # t1230: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1231 = prims.neg(t1230) # t1231: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1232 = prims.convert_element_type(t1231, dtypes.bfloat16) # t1232: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " # t1245 = prims.convert_element_type(t1244, dtypes.float32) # t1245: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1246 = prims.neg(t1245) # t1246: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1247 = prims.convert_element_type(t1246, dtypes.bfloat16) # t1247: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " del t1229, t1244\n",
+ " t1233 = torch.cat((t1232, t1228), -1) # t1233: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1233 = ltorch.cat((t1232, t1228), -1) # t1233: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1233 = prims.cat((t1232, t1228), -1) # t1233: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1232, t1228\n",
+ " t1248 = torch.cat((t1247, t1243), -1) # t1248: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1248 = ltorch.cat((t1247, t1243), -1) # t1248: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1248 = prims.cat((t1247, t1243), -1) # t1248: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1247, t1243\n",
+ " [t1241, t1256] = nvFusion52(t1227, t1233, t1242, t1248, t154, t157)\n",
+ " # t1235 = prims.convert_element_type(t1227, dtypes.float32) # t1235: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1250 = prims.convert_element_type(t1242, dtypes.float32) # t1250: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1236 = prims.mul(t1235, t154) # t1236: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1238 = prims.convert_element_type(t1233, dtypes.float32) # t1238: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1239 = prims.mul(t1238, t157) # t1239: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1240 = prims.add(t1236, t1239) # t1240: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1241 = prims.convert_element_type(t1240, dtypes.bfloat16) # t1241: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1251 = prims.mul(t1250, t154) # t1251: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1253 = prims.convert_element_type(t1248, dtypes.float32) # t1253: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1254 = prims.mul(t1253, t157) # t1254: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1255 = prims.add(t1251, t1254) # t1255: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1256 = prims.convert_element_type(t1255, dtypes.bfloat16) # t1256: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1227, t1233, t1242, t1248\n",
+ " t1258 = torch.cat((t1241, t1257), -1) # t1258: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1258 = ltorch.cat((t1241, t1257), -1) # t1258: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1258 = prims.cat((t1241, t1257), -1) # t1258: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1241, t1257\n",
+ " t1260 = torch.cat((t1256, t1259), -1) # t1260: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1260 = ltorch.cat((t1256, t1259), -1) # t1260: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1260 = prims.cat((t1256, t1259), -1) # t1260: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1256, t1259\n",
+ " (t1261, t1262, t1263, t1264, _, _, t1265, t1266, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1258, t1260, t1226, 0.0, True, scale=0.08838834764831843)\n",
+ " t1268 = torch.permute(t1261, (0, 2, 1, 3)) # t1268: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t1268 = ltorch.permute(t1261, (0, 2, 1, 3)) # t1268: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t1268 = prims.transpose(t1261, (0, 2, 1, 3)) # t1268: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " t1269 = torch.reshape(t1268, (1, 512, 4096)) # t1269: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1269 = ltorch.reshape(t1268, (1, 512, 4096)) # t1269: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1269 = prims.reshape(t1268, (1, 512, 4096)) # t1269: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t1268\n",
+ " t1270 = torch.nn.functional.linear(t1269, t105, None) # t1270: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1270 = ltorch.linear(t1269, t105, None) # t1270: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1270 = prims.linear(t1269, t105, None) # t1270: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t1274, t1281, t1289] = nvFusion53(t1202, t1270, t1285)\n",
+ " # t1272 = prims.convert_element_type(t1202, dtypes.float32) # t1272: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1271 = prims.convert_element_type(t1270, dtypes.float32) # t1271: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1273 = prims.add(t1271, t1272) # t1273: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1274 = prims.convert_element_type(t1273, dtypes.bfloat16) # t1274: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1276 = prims.mul(t1273, t1273) # t1276: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1277 = prims.sum(t1276, (2,)) # t1277: \"cuda:0 f32[1, 512]\"\n",
+ " # t1278 = prims.broadcast_in_dim(t1277, [1, 512, 1], [0, 1]) # t1278: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1279 = prims.div(t1278, 4096.0) # t1279: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1280 = prims.add(t1279, 1e-05) # t1280: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1281 = prims.rsqrt(t1280) # t1281: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1282 = prims.broadcast_in_dim(t1281, (1, 512, 4096), (0, 1, 2)) # t1282: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1283 = prims.mul(t1273, t1282) # t1283: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1287 = prims.convert_element_type(t1285, dtypes.float32) # t1287: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1288 = prims.mul(t1283, t1287) # t1288: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1289 = prims.convert_element_type(t1288, dtypes.bfloat16) # t1289: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t1290 = torch.nn.functional.linear(t1289, t29, None) # t1290: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1290 = ltorch.linear(t1289, t29, None) # t1290: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1290 = prims.linear(t1289, t29, None) # t1290: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t1291 = torch.nn.functional.linear(t1289, t45, None) # t1291: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1291 = ltorch.linear(t1289, t45, None) # t1291: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1291 = prims.linear(t1289, t45, None) # t1291: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " [t1305] = nvFusion54(t1290, t1291)\n",
+ " # t1292 = prims.convert_element_type(t1290, dtypes.float32) # t1292: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1293 = prims.neg(t1292) # t1293: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1294 = prims.exp(t1293) # t1294: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1295 = prims.add(1.0, t1294) # t1295: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1296 = prims.reciprocal(t1295) # t1296: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1300 = prims.mul(t1292, t1296) # t1300: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1303 = prims.convert_element_type(t1291, dtypes.float32) # t1303: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1304 = prims.mul(t1300, t1303) # t1304: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1305 = prims.convert_element_type(t1304, dtypes.bfloat16) # t1305: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t1306 = torch.nn.functional.linear(t1305, t106, None) # t1306: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1306 = ltorch.linear(t1305, t106, None) # t1306: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1306 = prims.linear(t1305, t106, None) # t1306: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t1310, t1317, t1325] = nvFusion55(t1274, t1306, t1321)\n",
+ " # t1308 = prims.convert_element_type(t1274, dtypes.float32) # t1308: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1307 = prims.convert_element_type(t1306, dtypes.float32) # t1307: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1309 = prims.add(t1307, t1308) # t1309: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1310 = prims.convert_element_type(t1309, dtypes.bfloat16) # t1310: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1312 = prims.mul(t1309, t1309) # t1312: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1313 = prims.sum(t1312, (2,)) # t1313: \"cuda:0 f32[1, 512]\"\n",
+ " # t1314 = prims.broadcast_in_dim(t1313, [1, 512, 1], [0, 1]) # t1314: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1315 = prims.div(t1314, 4096.0) # t1315: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1316 = prims.add(t1315, 1e-05) # t1316: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1317 = prims.rsqrt(t1316) # t1317: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1318 = prims.broadcast_in_dim(t1317, (1, 512, 4096), (0, 1, 2)) # t1318: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1319 = prims.mul(t1309, t1318) # t1319: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1323 = prims.convert_element_type(t1321, dtypes.float32) # t1323: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1324 = prims.mul(t1319, t1323) # t1324: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1325 = prims.convert_element_type(t1324, dtypes.bfloat16) # t1325: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t1326 = torch.nn.functional.linear(t1325, t14, None) # t1326: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t1326 = ltorch.linear(t1325, t14, None) # t1326: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t1326 = prims.linear(t1325, t14, None) # t1326: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " t1327 = torch.reshape(t1326, (1, 512, 32, 3, 128)) # t1327: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t1327 = ltorch.reshape(t1326, (1, 512, 32, 3, 128)) # t1327: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t1327 = prims.reshape(t1326, (1, 512, 32, 3, 128)) # t1327: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " del t1326\n",
+ " t1328 = torch.permute(t1327, (0, 2, 3, 1, 4)) # t1328: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t1328 = ltorch.permute(t1327, (0, 2, 3, 1, 4)) # t1328: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t1328 = prims.transpose(t1327, (0, 2, 3, 1, 4)) # t1328: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " del t1327\n",
+ " (t1329, t1330, t1331) = torch.split(t1328, (1, 1, 1), 2)\n",
+ " # (t1329, t1330, t1331) = ltorch.split(t1328, (1, 1, 1), 2)\n",
+ " # t1329 = prims.slice_prim(t1328, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1329: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t1330 = prims.slice_prim(t1328, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1330: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t1331 = prims.slice_prim(t1328, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1331: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " del t1328\n",
+ " t1332 = torch.reshape(t1329, (1, 32, 512, 128)) # t1332: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1332 = ltorch.reshape(t1329, (1, 32, 512, 128)) # t1332: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1332 = prims.reshape(t1329, (1, 32, 512, 128)) # t1332: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1329\n",
+ " t1333 = torch.reshape(t1330, (1, 32, 512, 128)) # t1333: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1333 = ltorch.reshape(t1330, (1, 32, 512, 128)) # t1333: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1333 = prims.reshape(t1330, (1, 32, 512, 128)) # t1333: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1330\n",
+ " t1334 = torch.reshape(t1331, (1, 32, 512, 128)) # t1334: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1334 = ltorch.reshape(t1331, (1, 32, 512, 128)) # t1334: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1334 = prims.reshape(t1331, (1, 32, 512, 128)) # t1334: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1331\n",
+ " t1335 = torch_slice_prim_impl(t1332, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1335: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t1350 = torch_slice_prim_impl(t1333, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1350: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t1365 = torch_slice_prim_impl(t1332, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1365: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t1332\n",
+ " t1367 = torch_slice_prim_impl(t1333, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1367: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t1333\n",
+ " t1336 = torch_slice_prim_impl(t1335, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1336: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1337 = torch_slice_prim_impl(t1335, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1337: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1351 = torch_slice_prim_impl(t1350, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1351: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1352 = torch_slice_prim_impl(t1350, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1352: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " [t1340, t1355] = nvFusion56(t1335, t1337, t1350, t1352)\n",
+ " # t1338 = prims.convert_element_type(t1337, dtypes.float32) # t1338: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1339 = prims.neg(t1338) # t1339: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1340 = prims.convert_element_type(t1339, dtypes.bfloat16) # t1340: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " # t1353 = prims.convert_element_type(t1352, dtypes.float32) # t1353: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1354 = prims.neg(t1353) # t1354: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1355 = prims.convert_element_type(t1354, dtypes.bfloat16) # t1355: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " del t1337, t1352\n",
+ " t1341 = torch.cat((t1340, t1336), -1) # t1341: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1341 = ltorch.cat((t1340, t1336), -1) # t1341: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1341 = prims.cat((t1340, t1336), -1) # t1341: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1340, t1336\n",
+ " t1356 = torch.cat((t1355, t1351), -1) # t1356: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1356 = ltorch.cat((t1355, t1351), -1) # t1356: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1356 = prims.cat((t1355, t1351), -1) # t1356: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1355, t1351\n",
+ " [t1349, t1364] = nvFusion57(t1335, t1341, t1350, t1356, t154, t157)\n",
+ " # t1343 = prims.convert_element_type(t1335, dtypes.float32) # t1343: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1358 = prims.convert_element_type(t1350, dtypes.float32) # t1358: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1344 = prims.mul(t1343, t154) # t1344: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1346 = prims.convert_element_type(t1341, dtypes.float32) # t1346: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1347 = prims.mul(t1346, t157) # t1347: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1348 = prims.add(t1344, t1347) # t1348: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1349 = prims.convert_element_type(t1348, dtypes.bfloat16) # t1349: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1359 = prims.mul(t1358, t154) # t1359: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1361 = prims.convert_element_type(t1356, dtypes.float32) # t1361: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1362 = prims.mul(t1361, t157) # t1362: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1363 = prims.add(t1359, t1362) # t1363: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1364 = prims.convert_element_type(t1363, dtypes.bfloat16) # t1364: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1335, t1341, t1350, t1356\n",
+ " t1366 = torch.cat((t1349, t1365), -1) # t1366: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1366 = ltorch.cat((t1349, t1365), -1) # t1366: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1366 = prims.cat((t1349, t1365), -1) # t1366: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1349, t1365\n",
+ " t1368 = torch.cat((t1364, t1367), -1) # t1368: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1368 = ltorch.cat((t1364, t1367), -1) # t1368: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1368 = prims.cat((t1364, t1367), -1) # t1368: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1364, t1367\n",
+ " (t1369, t1370, t1371, t1372, _, _, t1373, t1374, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1366, t1368, t1334, 0.0, True, scale=0.08838834764831843)\n",
+ " t1376 = torch.permute(t1369, (0, 2, 1, 3)) # t1376: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t1376 = ltorch.permute(t1369, (0, 2, 1, 3)) # t1376: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t1376 = prims.transpose(t1369, (0, 2, 1, 3)) # t1376: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " t1377 = torch.reshape(t1376, (1, 512, 4096)) # t1377: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1377 = ltorch.reshape(t1376, (1, 512, 4096)) # t1377: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1377 = prims.reshape(t1376, (1, 512, 4096)) # t1377: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t1376\n",
+ " t1378 = torch.nn.functional.linear(t1377, t107, None) # t1378: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1378 = ltorch.linear(t1377, t107, None) # t1378: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1378 = prims.linear(t1377, t107, None) # t1378: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t1382, t1389, t1397] = nvFusion58(t1310, t1378, t1393)\n",
+ " # t1380 = prims.convert_element_type(t1310, dtypes.float32) # t1380: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1379 = prims.convert_element_type(t1378, dtypes.float32) # t1379: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1381 = prims.add(t1379, t1380) # t1381: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1382 = prims.convert_element_type(t1381, dtypes.bfloat16) # t1382: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1384 = prims.mul(t1381, t1381) # t1384: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1385 = prims.sum(t1384, (2,)) # t1385: \"cuda:0 f32[1, 512]\"\n",
+ " # t1386 = prims.broadcast_in_dim(t1385, [1, 512, 1], [0, 1]) # t1386: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1387 = prims.div(t1386, 4096.0) # t1387: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1388 = prims.add(t1387, 1e-05) # t1388: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1389 = prims.rsqrt(t1388) # t1389: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1390 = prims.broadcast_in_dim(t1389, (1, 512, 4096), (0, 1, 2)) # t1390: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1391 = prims.mul(t1381, t1390) # t1391: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1395 = prims.convert_element_type(t1393, dtypes.float32) # t1395: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1396 = prims.mul(t1391, t1395) # t1396: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1397 = prims.convert_element_type(t1396, dtypes.bfloat16) # t1397: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t1398 = torch.nn.functional.linear(t1397, t30, None) # t1398: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1398 = ltorch.linear(t1397, t30, None) # t1398: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1398 = prims.linear(t1397, t30, None) # t1398: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t1399 = torch.nn.functional.linear(t1397, t46, None) # t1399: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1399 = ltorch.linear(t1397, t46, None) # t1399: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1399 = prims.linear(t1397, t46, None) # t1399: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " [t1413] = nvFusion59(t1398, t1399)\n",
+ " # t1400 = prims.convert_element_type(t1398, dtypes.float32) # t1400: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1401 = prims.neg(t1400) # t1401: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1402 = prims.exp(t1401) # t1402: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1403 = prims.add(1.0, t1402) # t1403: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1404 = prims.reciprocal(t1403) # t1404: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1408 = prims.mul(t1400, t1404) # t1408: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1411 = prims.convert_element_type(t1399, dtypes.float32) # t1411: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1412 = prims.mul(t1408, t1411) # t1412: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1413 = prims.convert_element_type(t1412, dtypes.bfloat16) # t1413: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t1414 = torch.nn.functional.linear(t1413, t108, None) # t1414: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1414 = ltorch.linear(t1413, t108, None) # t1414: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1414 = prims.linear(t1413, t108, None) # t1414: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t1418, t1425, t1433] = nvFusion60(t1382, t1414, t1429)\n",
+ " # t1416 = prims.convert_element_type(t1382, dtypes.float32) # t1416: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1415 = prims.convert_element_type(t1414, dtypes.float32) # t1415: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1417 = prims.add(t1415, t1416) # t1417: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1418 = prims.convert_element_type(t1417, dtypes.bfloat16) # t1418: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1420 = prims.mul(t1417, t1417) # t1420: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1421 = prims.sum(t1420, (2,)) # t1421: \"cuda:0 f32[1, 512]\"\n",
+ " # t1422 = prims.broadcast_in_dim(t1421, [1, 512, 1], [0, 1]) # t1422: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1423 = prims.div(t1422, 4096.0) # t1423: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1424 = prims.add(t1423, 1e-05) # t1424: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1425 = prims.rsqrt(t1424) # t1425: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1426 = prims.broadcast_in_dim(t1425, (1, 512, 4096), (0, 1, 2)) # t1426: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1427 = prims.mul(t1417, t1426) # t1427: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1431 = prims.convert_element_type(t1429, dtypes.float32) # t1431: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1432 = prims.mul(t1427, t1431) # t1432: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1433 = prims.convert_element_type(t1432, dtypes.bfloat16) # t1433: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t1434 = torch.nn.functional.linear(t1433, t15, None) # t1434: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t1434 = ltorch.linear(t1433, t15, None) # t1434: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t1434 = prims.linear(t1433, t15, None) # t1434: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " t1435 = torch.reshape(t1434, (1, 512, 32, 3, 128)) # t1435: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t1435 = ltorch.reshape(t1434, (1, 512, 32, 3, 128)) # t1435: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t1435 = prims.reshape(t1434, (1, 512, 32, 3, 128)) # t1435: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " del t1434\n",
+ " t1436 = torch.permute(t1435, (0, 2, 3, 1, 4)) # t1436: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t1436 = ltorch.permute(t1435, (0, 2, 3, 1, 4)) # t1436: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t1436 = prims.transpose(t1435, (0, 2, 3, 1, 4)) # t1436: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " del t1435\n",
+ " (t1437, t1438, t1439) = torch.split(t1436, (1, 1, 1), 2)\n",
+ " # (t1437, t1438, t1439) = ltorch.split(t1436, (1, 1, 1), 2)\n",
+ " # t1437 = prims.slice_prim(t1436, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1437: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t1438 = prims.slice_prim(t1436, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1438: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t1439 = prims.slice_prim(t1436, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1439: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " del t1436\n",
+ " t1440 = torch.reshape(t1437, (1, 32, 512, 128)) # t1440: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1440 = ltorch.reshape(t1437, (1, 32, 512, 128)) # t1440: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1440 = prims.reshape(t1437, (1, 32, 512, 128)) # t1440: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1437\n",
+ " t1441 = torch.reshape(t1438, (1, 32, 512, 128)) # t1441: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1441 = ltorch.reshape(t1438, (1, 32, 512, 128)) # t1441: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1441 = prims.reshape(t1438, (1, 32, 512, 128)) # t1441: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1438\n",
+ " t1442 = torch.reshape(t1439, (1, 32, 512, 128)) # t1442: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1442 = ltorch.reshape(t1439, (1, 32, 512, 128)) # t1442: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1442 = prims.reshape(t1439, (1, 32, 512, 128)) # t1442: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1439\n",
+ " t1443 = torch_slice_prim_impl(t1440, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1443: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t1458 = torch_slice_prim_impl(t1441, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1458: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t1473 = torch_slice_prim_impl(t1440, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1473: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t1440\n",
+ " t1475 = torch_slice_prim_impl(t1441, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1475: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t1441\n",
+ " t1444 = torch_slice_prim_impl(t1443, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1444: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1445 = torch_slice_prim_impl(t1443, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1445: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1459 = torch_slice_prim_impl(t1458, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1459: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1460 = torch_slice_prim_impl(t1458, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1460: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " [t1448, t1463] = nvFusion61(t1443, t1445, t1458, t1460)\n",
+ " # t1446 = prims.convert_element_type(t1445, dtypes.float32) # t1446: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1447 = prims.neg(t1446) # t1447: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1448 = prims.convert_element_type(t1447, dtypes.bfloat16) # t1448: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " # t1461 = prims.convert_element_type(t1460, dtypes.float32) # t1461: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1462 = prims.neg(t1461) # t1462: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1463 = prims.convert_element_type(t1462, dtypes.bfloat16) # t1463: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " del t1445, t1460\n",
+ " t1464 = torch.cat((t1463, t1459), -1) # t1464: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1464 = ltorch.cat((t1463, t1459), -1) # t1464: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1464 = prims.cat((t1463, t1459), -1) # t1464: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1463, t1459\n",
+ " t1449 = torch.cat((t1448, t1444), -1) # t1449: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1449 = ltorch.cat((t1448, t1444), -1) # t1449: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1449 = prims.cat((t1448, t1444), -1) # t1449: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1448, t1444\n",
+ " [t1457, t1472] = nvFusion62(t1443, t1449, t1458, t1464, t154, t157)\n",
+ " # t1451 = prims.convert_element_type(t1443, dtypes.float32) # t1451: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1466 = prims.convert_element_type(t1458, dtypes.float32) # t1466: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1467 = prims.mul(t1466, t154) # t1467: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1469 = prims.convert_element_type(t1464, dtypes.float32) # t1469: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1470 = prims.mul(t1469, t157) # t1470: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1471 = prims.add(t1467, t1470) # t1471: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1472 = prims.convert_element_type(t1471, dtypes.bfloat16) # t1472: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1452 = prims.mul(t1451, t154) # t1452: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1454 = prims.convert_element_type(t1449, dtypes.float32) # t1454: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1455 = prims.mul(t1454, t157) # t1455: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1456 = prims.add(t1452, t1455) # t1456: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1457 = prims.convert_element_type(t1456, dtypes.bfloat16) # t1457: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1443, t1449, t1458, t1464\n",
+ " t1476 = torch.cat((t1472, t1475), -1) # t1476: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1476 = ltorch.cat((t1472, t1475), -1) # t1476: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1476 = prims.cat((t1472, t1475), -1) # t1476: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1472, t1475\n",
+ " t1474 = torch.cat((t1457, t1473), -1) # t1474: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1474 = ltorch.cat((t1457, t1473), -1) # t1474: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1474 = prims.cat((t1457, t1473), -1) # t1474: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1457, t1473\n",
+ " (t1477, t1478, t1479, t1480, _, _, t1481, t1482, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1474, t1476, t1442, 0.0, True, scale=0.08838834764831843)\n",
+ " t1484 = torch.permute(t1477, (0, 2, 1, 3)) # t1484: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t1484 = ltorch.permute(t1477, (0, 2, 1, 3)) # t1484: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t1484 = prims.transpose(t1477, (0, 2, 1, 3)) # t1484: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " t1485 = torch.reshape(t1484, (1, 512, 4096)) # t1485: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1485 = ltorch.reshape(t1484, (1, 512, 4096)) # t1485: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1485 = prims.reshape(t1484, (1, 512, 4096)) # t1485: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t1484\n",
+ " t1486 = torch.nn.functional.linear(t1485, t109, None) # t1486: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1486 = ltorch.linear(t1485, t109, None) # t1486: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1486 = prims.linear(t1485, t109, None) # t1486: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t1490, t1497, t1505] = nvFusion63(t1418, t1486, t1501)\n",
+ " # t1488 = prims.convert_element_type(t1418, dtypes.float32) # t1488: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1487 = prims.convert_element_type(t1486, dtypes.float32) # t1487: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1489 = prims.add(t1487, t1488) # t1489: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1490 = prims.convert_element_type(t1489, dtypes.bfloat16) # t1490: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1492 = prims.mul(t1489, t1489) # t1492: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1493 = prims.sum(t1492, (2,)) # t1493: \"cuda:0 f32[1, 512]\"\n",
+ " # t1494 = prims.broadcast_in_dim(t1493, [1, 512, 1], [0, 1]) # t1494: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1495 = prims.div(t1494, 4096.0) # t1495: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1496 = prims.add(t1495, 1e-05) # t1496: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1497 = prims.rsqrt(t1496) # t1497: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1498 = prims.broadcast_in_dim(t1497, (1, 512, 4096), (0, 1, 2)) # t1498: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1499 = prims.mul(t1489, t1498) # t1499: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1503 = prims.convert_element_type(t1501, dtypes.float32) # t1503: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1504 = prims.mul(t1499, t1503) # t1504: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1505 = prims.convert_element_type(t1504, dtypes.bfloat16) # t1505: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t1506 = torch.nn.functional.linear(t1505, t31, None) # t1506: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1506 = ltorch.linear(t1505, t31, None) # t1506: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1506 = prims.linear(t1505, t31, None) # t1506: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t1507 = torch.nn.functional.linear(t1505, t47, None) # t1507: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1507 = ltorch.linear(t1505, t47, None) # t1507: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1507 = prims.linear(t1505, t47, None) # t1507: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " [t1521] = nvFusion64(t1506, t1507)\n",
+ " # t1508 = prims.convert_element_type(t1506, dtypes.float32) # t1508: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1509 = prims.neg(t1508) # t1509: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1510 = prims.exp(t1509) # t1510: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1511 = prims.add(1.0, t1510) # t1511: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1512 = prims.reciprocal(t1511) # t1512: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1516 = prims.mul(t1508, t1512) # t1516: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1519 = prims.convert_element_type(t1507, dtypes.float32) # t1519: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1520 = prims.mul(t1516, t1519) # t1520: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1521 = prims.convert_element_type(t1520, dtypes.bfloat16) # t1521: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t1522 = torch.nn.functional.linear(t1521, t110, None) # t1522: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1522 = ltorch.linear(t1521, t110, None) # t1522: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1522 = prims.linear(t1521, t110, None) # t1522: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t1526, t1533, t1541] = nvFusion65(t1490, t1522, t1537)\n",
+ " # t1524 = prims.convert_element_type(t1490, dtypes.float32) # t1524: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1523 = prims.convert_element_type(t1522, dtypes.float32) # t1523: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1525 = prims.add(t1523, t1524) # t1525: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1526 = prims.convert_element_type(t1525, dtypes.bfloat16) # t1526: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1528 = prims.mul(t1525, t1525) # t1528: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1529 = prims.sum(t1528, (2,)) # t1529: \"cuda:0 f32[1, 512]\"\n",
+ " # t1530 = prims.broadcast_in_dim(t1529, [1, 512, 1], [0, 1]) # t1530: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1531 = prims.div(t1530, 4096.0) # t1531: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1532 = prims.add(t1531, 1e-05) # t1532: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1533 = prims.rsqrt(t1532) # t1533: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1534 = prims.broadcast_in_dim(t1533, (1, 512, 4096), (0, 1, 2)) # t1534: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1535 = prims.mul(t1525, t1534) # t1535: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1539 = prims.convert_element_type(t1537, dtypes.float32) # t1539: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1540 = prims.mul(t1535, t1539) # t1540: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1541 = prims.convert_element_type(t1540, dtypes.bfloat16) # t1541: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t1542 = torch.nn.functional.linear(t1541, t16, None) # t1542: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t1542 = ltorch.linear(t1541, t16, None) # t1542: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t1542 = prims.linear(t1541, t16, None) # t1542: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " t1543 = torch.reshape(t1542, (1, 512, 32, 3, 128)) # t1543: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t1543 = ltorch.reshape(t1542, (1, 512, 32, 3, 128)) # t1543: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t1543 = prims.reshape(t1542, (1, 512, 32, 3, 128)) # t1543: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " del t1542\n",
+ " t1544 = torch.permute(t1543, (0, 2, 3, 1, 4)) # t1544: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t1544 = ltorch.permute(t1543, (0, 2, 3, 1, 4)) # t1544: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t1544 = prims.transpose(t1543, (0, 2, 3, 1, 4)) # t1544: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " del t1543\n",
+ " (t1545, t1546, t1547) = torch.split(t1544, (1, 1, 1), 2)\n",
+ " # (t1545, t1546, t1547) = ltorch.split(t1544, (1, 1, 1), 2)\n",
+ " # t1545 = prims.slice_prim(t1544, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1545: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t1546 = prims.slice_prim(t1544, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1546: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t1547 = prims.slice_prim(t1544, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1547: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " del t1544\n",
+ " t1548 = torch.reshape(t1545, (1, 32, 512, 128)) # t1548: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1548 = ltorch.reshape(t1545, (1, 32, 512, 128)) # t1548: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1548 = prims.reshape(t1545, (1, 32, 512, 128)) # t1548: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1545\n",
+ " t1549 = torch.reshape(t1546, (1, 32, 512, 128)) # t1549: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1549 = ltorch.reshape(t1546, (1, 32, 512, 128)) # t1549: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1549 = prims.reshape(t1546, (1, 32, 512, 128)) # t1549: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1546\n",
+ " t1550 = torch.reshape(t1547, (1, 32, 512, 128)) # t1550: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1550 = ltorch.reshape(t1547, (1, 32, 512, 128)) # t1550: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1550 = prims.reshape(t1547, (1, 32, 512, 128)) # t1550: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1547\n",
+ " t1551 = torch_slice_prim_impl(t1548, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1551: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t1566 = torch_slice_prim_impl(t1549, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1566: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t1581 = torch_slice_prim_impl(t1548, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1581: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t1548\n",
+ " t1583 = torch_slice_prim_impl(t1549, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1583: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t1549\n",
+ " t1552 = torch_slice_prim_impl(t1551, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1552: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1553 = torch_slice_prim_impl(t1551, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1553: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1567 = torch_slice_prim_impl(t1566, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1567: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1568 = torch_slice_prim_impl(t1566, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1568: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " [t1556, t1571] = nvFusion66(t1551, t1553, t1566, t1568)\n",
+ " # t1554 = prims.convert_element_type(t1553, dtypes.float32) # t1554: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1555 = prims.neg(t1554) # t1555: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1556 = prims.convert_element_type(t1555, dtypes.bfloat16) # t1556: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " # t1569 = prims.convert_element_type(t1568, dtypes.float32) # t1569: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1570 = prims.neg(t1569) # t1570: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1571 = prims.convert_element_type(t1570, dtypes.bfloat16) # t1571: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " del t1553, t1568\n",
+ " t1572 = torch.cat((t1571, t1567), -1) # t1572: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1572 = ltorch.cat((t1571, t1567), -1) # t1572: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1572 = prims.cat((t1571, t1567), -1) # t1572: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1571, t1567\n",
+ " t1557 = torch.cat((t1556, t1552), -1) # t1557: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1557 = ltorch.cat((t1556, t1552), -1) # t1557: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1557 = prims.cat((t1556, t1552), -1) # t1557: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1556, t1552\n",
+ " [t1565, t1580] = nvFusion67(t154, t1551, t1557, t1566, t157, t1572)\n",
+ " # t1559 = prims.convert_element_type(t1551, dtypes.float32) # t1559: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1574 = prims.convert_element_type(t1566, dtypes.float32) # t1574: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1575 = prims.mul(t1574, t154) # t1575: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1577 = prims.convert_element_type(t1572, dtypes.float32) # t1577: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1578 = prims.mul(t1577, t157) # t1578: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1579 = prims.add(t1575, t1578) # t1579: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1580 = prims.convert_element_type(t1579, dtypes.bfloat16) # t1580: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1560 = prims.mul(t1559, t154) # t1560: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1562 = prims.convert_element_type(t1557, dtypes.float32) # t1562: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1563 = prims.mul(t1562, t157) # t1563: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1564 = prims.add(t1560, t1563) # t1564: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1565 = prims.convert_element_type(t1564, dtypes.bfloat16) # t1565: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1551, t1557, t1566, t1572\n",
+ " t1584 = torch.cat((t1580, t1583), -1) # t1584: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1584 = ltorch.cat((t1580, t1583), -1) # t1584: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1584 = prims.cat((t1580, t1583), -1) # t1584: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1580, t1583\n",
+ " t1582 = torch.cat((t1565, t1581), -1) # t1582: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1582 = ltorch.cat((t1565, t1581), -1) # t1582: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1582 = prims.cat((t1565, t1581), -1) # t1582: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1565, t1581\n",
+ " (t1585, t1586, t1587, t1588, _, _, t1589, t1590, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1582, t1584, t1550, 0.0, True, scale=0.08838834764831843)\n",
+ " t1592 = torch.permute(t1585, (0, 2, 1, 3)) # t1592: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t1592 = ltorch.permute(t1585, (0, 2, 1, 3)) # t1592: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t1592 = prims.transpose(t1585, (0, 2, 1, 3)) # t1592: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " t1593 = torch.reshape(t1592, (1, 512, 4096)) # t1593: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1593 = ltorch.reshape(t1592, (1, 512, 4096)) # t1593: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1593 = prims.reshape(t1592, (1, 512, 4096)) # t1593: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t1592\n",
+ " t1594 = torch.nn.functional.linear(t1593, t111, None) # t1594: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1594 = ltorch.linear(t1593, t111, None) # t1594: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1594 = prims.linear(t1593, t111, None) # t1594: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t1598, t1605, t1613] = nvFusion68(t1526, t1594, t1609)\n",
+ " # t1596 = prims.convert_element_type(t1526, dtypes.float32) # t1596: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1595 = prims.convert_element_type(t1594, dtypes.float32) # t1595: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1597 = prims.add(t1595, t1596) # t1597: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1598 = prims.convert_element_type(t1597, dtypes.bfloat16) # t1598: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1600 = prims.mul(t1597, t1597) # t1600: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1601 = prims.sum(t1600, (2,)) # t1601: \"cuda:0 f32[1, 512]\"\n",
+ " # t1602 = prims.broadcast_in_dim(t1601, [1, 512, 1], [0, 1]) # t1602: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1603 = prims.div(t1602, 4096.0) # t1603: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1604 = prims.add(t1603, 1e-05) # t1604: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1605 = prims.rsqrt(t1604) # t1605: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1606 = prims.broadcast_in_dim(t1605, (1, 512, 4096), (0, 1, 2)) # t1606: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1607 = prims.mul(t1597, t1606) # t1607: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1611 = prims.convert_element_type(t1609, dtypes.float32) # t1611: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1612 = prims.mul(t1607, t1611) # t1612: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1613 = prims.convert_element_type(t1612, dtypes.bfloat16) # t1613: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t1614 = torch.nn.functional.linear(t1613, t32, None) # t1614: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1614 = ltorch.linear(t1613, t32, None) # t1614: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1614 = prims.linear(t1613, t32, None) # t1614: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t1615 = torch.nn.functional.linear(t1613, t48, None) # t1615: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1615 = ltorch.linear(t1613, t48, None) # t1615: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1615 = prims.linear(t1613, t48, None) # t1615: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " [t1629] = nvFusion69(t1614, t1615)\n",
+ " # t1616 = prims.convert_element_type(t1614, dtypes.float32) # t1616: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1617 = prims.neg(t1616) # t1617: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1618 = prims.exp(t1617) # t1618: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1619 = prims.add(1.0, t1618) # t1619: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1620 = prims.reciprocal(t1619) # t1620: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1624 = prims.mul(t1616, t1620) # t1624: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1627 = prims.convert_element_type(t1615, dtypes.float32) # t1627: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1628 = prims.mul(t1624, t1627) # t1628: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1629 = prims.convert_element_type(t1628, dtypes.bfloat16) # t1629: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t1630 = torch.nn.functional.linear(t1629, t112, None) # t1630: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1630 = ltorch.linear(t1629, t112, None) # t1630: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1630 = prims.linear(t1629, t112, None) # t1630: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t1634, t1641, t1649] = nvFusion70(t1598, t1630, t1645)\n",
+ " # t1632 = prims.convert_element_type(t1598, dtypes.float32) # t1632: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1631 = prims.convert_element_type(t1630, dtypes.float32) # t1631: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1633 = prims.add(t1631, t1632) # t1633: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1634 = prims.convert_element_type(t1633, dtypes.bfloat16) # t1634: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1636 = prims.mul(t1633, t1633) # t1636: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1637 = prims.sum(t1636, (2,)) # t1637: \"cuda:0 f32[1, 512]\"\n",
+ " # t1638 = prims.broadcast_in_dim(t1637, [1, 512, 1], [0, 1]) # t1638: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1639 = prims.div(t1638, 4096.0) # t1639: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1640 = prims.add(t1639, 1e-05) # t1640: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1641 = prims.rsqrt(t1640) # t1641: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1642 = prims.broadcast_in_dim(t1641, (1, 512, 4096), (0, 1, 2)) # t1642: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1643 = prims.mul(t1633, t1642) # t1643: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1647 = prims.convert_element_type(t1645, dtypes.float32) # t1647: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1648 = prims.mul(t1643, t1647) # t1648: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1649 = prims.convert_element_type(t1648, dtypes.bfloat16) # t1649: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t1650 = torch.nn.functional.linear(t1649, t17, None) # t1650: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t1650 = ltorch.linear(t1649, t17, None) # t1650: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t1650 = prims.linear(t1649, t17, None) # t1650: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " t1651 = torch.reshape(t1650, (1, 512, 32, 3, 128)) # t1651: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t1651 = ltorch.reshape(t1650, (1, 512, 32, 3, 128)) # t1651: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t1651 = prims.reshape(t1650, (1, 512, 32, 3, 128)) # t1651: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " del t1650\n",
+ " t1652 = torch.permute(t1651, (0, 2, 3, 1, 4)) # t1652: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t1652 = ltorch.permute(t1651, (0, 2, 3, 1, 4)) # t1652: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t1652 = prims.transpose(t1651, (0, 2, 3, 1, 4)) # t1652: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " del t1651\n",
+ " (t1653, t1654, t1655) = torch.split(t1652, (1, 1, 1), 2)\n",
+ " # (t1653, t1654, t1655) = ltorch.split(t1652, (1, 1, 1), 2)\n",
+ " # t1653 = prims.slice_prim(t1652, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1653: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t1654 = prims.slice_prim(t1652, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1654: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t1655 = prims.slice_prim(t1652, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1655: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " del t1652\n",
+ " t1656 = torch.reshape(t1653, (1, 32, 512, 128)) # t1656: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1656 = ltorch.reshape(t1653, (1, 32, 512, 128)) # t1656: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1656 = prims.reshape(t1653, (1, 32, 512, 128)) # t1656: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1653\n",
+ " t1657 = torch.reshape(t1654, (1, 32, 512, 128)) # t1657: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1657 = ltorch.reshape(t1654, (1, 32, 512, 128)) # t1657: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1657 = prims.reshape(t1654, (1, 32, 512, 128)) # t1657: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1654\n",
+ " t1658 = torch.reshape(t1655, (1, 32, 512, 128)) # t1658: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1658 = ltorch.reshape(t1655, (1, 32, 512, 128)) # t1658: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1658 = prims.reshape(t1655, (1, 32, 512, 128)) # t1658: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1655\n",
+ " t1689 = torch_slice_prim_impl(t1656, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1689: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " t1691 = torch_slice_prim_impl(t1657, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1691: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " t1659 = torch_slice_prim_impl(t1656, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1659: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1656\n",
+ " t1674 = torch_slice_prim_impl(t1657, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1674: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1657\n",
+ " t1660 = torch_slice_prim_impl(t1659, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1660: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1661 = torch_slice_prim_impl(t1659, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1661: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1675 = torch_slice_prim_impl(t1674, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1675: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1676 = torch_slice_prim_impl(t1674, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1676: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " [t1664, t1679] = nvFusion71(t1659, t1661, t1674, t1676)\n",
+ " # t1662 = prims.convert_element_type(t1661, dtypes.float32) # t1662: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1663 = prims.neg(t1662) # t1663: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1664 = prims.convert_element_type(t1663, dtypes.bfloat16) # t1664: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " # t1677 = prims.convert_element_type(t1676, dtypes.float32) # t1677: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1678 = prims.neg(t1677) # t1678: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1679 = prims.convert_element_type(t1678, dtypes.bfloat16) # t1679: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " del t1661, t1676\n",
+ " t1680 = torch.cat((t1679, t1675), -1) # t1680: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1680 = ltorch.cat((t1679, t1675), -1) # t1680: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1680 = prims.cat((t1679, t1675), -1) # t1680: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1679, t1675\n",
+ " t1665 = torch.cat((t1664, t1660), -1) # t1665: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1665 = ltorch.cat((t1664, t1660), -1) # t1665: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1665 = prims.cat((t1664, t1660), -1) # t1665: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1664, t1660\n",
+ " [t1673, t1688] = nvFusion72(t154, t157, t1659, t1665, t1674, t1680)\n",
+ " # t1667 = prims.convert_element_type(t1659, dtypes.float32) # t1667: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1682 = prims.convert_element_type(t1674, dtypes.float32) # t1682: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1683 = prims.mul(t1682, t154) # t1683: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1685 = prims.convert_element_type(t1680, dtypes.float32) # t1685: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1686 = prims.mul(t1685, t157) # t1686: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1687 = prims.add(t1683, t1686) # t1687: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1688 = prims.convert_element_type(t1687, dtypes.bfloat16) # t1688: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1668 = prims.mul(t1667, t154) # t1668: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1670 = prims.convert_element_type(t1665, dtypes.float32) # t1670: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1671 = prims.mul(t1670, t157) # t1671: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1672 = prims.add(t1668, t1671) # t1672: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1673 = prims.convert_element_type(t1672, dtypes.bfloat16) # t1673: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1659, t1665, t1674, t1680\n",
+ " t1692 = torch.cat((t1688, t1691), -1) # t1692: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1692 = ltorch.cat((t1688, t1691), -1) # t1692: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1692 = prims.cat((t1688, t1691), -1) # t1692: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1688, t1691\n",
+ " t1690 = torch.cat((t1673, t1689), -1) # t1690: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1690 = ltorch.cat((t1673, t1689), -1) # t1690: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1690 = prims.cat((t1673, t1689), -1) # t1690: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1673, t1689\n",
+ " (t1693, t1694, t1695, t1696, _, _, t1697, t1698, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1690, t1692, t1658, 0.0, True, scale=0.08838834764831843)\n",
+ " t1700 = torch.permute(t1693, (0, 2, 1, 3)) # t1700: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t1700 = ltorch.permute(t1693, (0, 2, 1, 3)) # t1700: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t1700 = prims.transpose(t1693, (0, 2, 1, 3)) # t1700: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " t1701 = torch.reshape(t1700, (1, 512, 4096)) # t1701: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1701 = ltorch.reshape(t1700, (1, 512, 4096)) # t1701: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1701 = prims.reshape(t1700, (1, 512, 4096)) # t1701: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t1700\n",
+ " t1702 = torch.nn.functional.linear(t1701, t113, None) # t1702: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1702 = ltorch.linear(t1701, t113, None) # t1702: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1702 = prims.linear(t1701, t113, None) # t1702: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t1706, t1713, t1721] = nvFusion73(t1634, t1702, t1717)\n",
+ " # t1704 = prims.convert_element_type(t1634, dtypes.float32) # t1704: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1703 = prims.convert_element_type(t1702, dtypes.float32) # t1703: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1705 = prims.add(t1703, t1704) # t1705: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1706 = prims.convert_element_type(t1705, dtypes.bfloat16) # t1706: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1708 = prims.mul(t1705, t1705) # t1708: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1709 = prims.sum(t1708, (2,)) # t1709: \"cuda:0 f32[1, 512]\"\n",
+ " # t1710 = prims.broadcast_in_dim(t1709, [1, 512, 1], [0, 1]) # t1710: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1711 = prims.div(t1710, 4096.0) # t1711: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1712 = prims.add(t1711, 1e-05) # t1712: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1713 = prims.rsqrt(t1712) # t1713: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1714 = prims.broadcast_in_dim(t1713, (1, 512, 4096), (0, 1, 2)) # t1714: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1715 = prims.mul(t1705, t1714) # t1715: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1719 = prims.convert_element_type(t1717, dtypes.float32) # t1719: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1720 = prims.mul(t1715, t1719) # t1720: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1721 = prims.convert_element_type(t1720, dtypes.bfloat16) # t1721: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t1722 = torch.nn.functional.linear(t1721, t33, None) # t1722: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1722 = ltorch.linear(t1721, t33, None) # t1722: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1722 = prims.linear(t1721, t33, None) # t1722: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t1723 = torch.nn.functional.linear(t1721, t49, None) # t1723: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1723 = ltorch.linear(t1721, t49, None) # t1723: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1723 = prims.linear(t1721, t49, None) # t1723: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " [t1737] = nvFusion74(t1722, t1723)\n",
+ " # t1724 = prims.convert_element_type(t1722, dtypes.float32) # t1724: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1725 = prims.neg(t1724) # t1725: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1726 = prims.exp(t1725) # t1726: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1727 = prims.add(1.0, t1726) # t1727: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1728 = prims.reciprocal(t1727) # t1728: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1732 = prims.mul(t1724, t1728) # t1732: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1735 = prims.convert_element_type(t1723, dtypes.float32) # t1735: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1736 = prims.mul(t1732, t1735) # t1736: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1737 = prims.convert_element_type(t1736, dtypes.bfloat16) # t1737: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t1738 = torch.nn.functional.linear(t1737, t114, None) # t1738: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1738 = ltorch.linear(t1737, t114, None) # t1738: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1738 = prims.linear(t1737, t114, None) # t1738: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t1742, t1749, t1757] = nvFusion75(t1706, t1738, t1753)\n",
+ " # t1740 = prims.convert_element_type(t1706, dtypes.float32) # t1740: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1739 = prims.convert_element_type(t1738, dtypes.float32) # t1739: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1741 = prims.add(t1739, t1740) # t1741: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1742 = prims.convert_element_type(t1741, dtypes.bfloat16) # t1742: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1744 = prims.mul(t1741, t1741) # t1744: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1745 = prims.sum(t1744, (2,)) # t1745: \"cuda:0 f32[1, 512]\"\n",
+ " # t1746 = prims.broadcast_in_dim(t1745, [1, 512, 1], [0, 1]) # t1746: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1747 = prims.div(t1746, 4096.0) # t1747: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1748 = prims.add(t1747, 1e-05) # t1748: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1749 = prims.rsqrt(t1748) # t1749: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1750 = prims.broadcast_in_dim(t1749, (1, 512, 4096), (0, 1, 2)) # t1750: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1751 = prims.mul(t1741, t1750) # t1751: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1755 = prims.convert_element_type(t1753, dtypes.float32) # t1755: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1756 = prims.mul(t1751, t1755) # t1756: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1757 = prims.convert_element_type(t1756, dtypes.bfloat16) # t1757: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t1758 = torch.nn.functional.linear(t1757, t18, None) # t1758: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t1758 = ltorch.linear(t1757, t18, None) # t1758: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " # t1758 = prims.linear(t1757, t18, None) # t1758: \"cuda:0 bf16[1, 512, 12288]\"\n",
+ " t1759 = torch.reshape(t1758, (1, 512, 32, 3, 128)) # t1759: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t1759 = ltorch.reshape(t1758, (1, 512, 32, 3, 128)) # t1759: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " # t1759 = prims.reshape(t1758, (1, 512, 32, 3, 128)) # t1759: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n",
+ " del t1758\n",
+ " t1760 = torch.permute(t1759, (0, 2, 3, 1, 4)) # t1760: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t1760 = ltorch.permute(t1759, (0, 2, 3, 1, 4)) # t1760: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " # t1760 = prims.transpose(t1759, (0, 2, 3, 1, 4)) # t1760: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n",
+ " del t1759\n",
+ " (t1761, t1762, t1763) = torch.split(t1760, (1, 1, 1), 2)\n",
+ " # (t1761, t1762, t1763) = ltorch.split(t1760, (1, 1, 1), 2)\n",
+ " # t1761 = prims.slice_prim(t1760, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1761: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t1762 = prims.slice_prim(t1760, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1762: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " # t1763 = prims.slice_prim(t1760, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1763: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n",
+ " del t1760\n",
+ " t1764 = torch.reshape(t1761, (1, 32, 512, 128)) # t1764: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1764 = ltorch.reshape(t1761, (1, 32, 512, 128)) # t1764: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1764 = prims.reshape(t1761, (1, 32, 512, 128)) # t1764: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1761\n",
+ " t1765 = torch.reshape(t1762, (1, 32, 512, 128)) # t1765: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1765 = ltorch.reshape(t1762, (1, 32, 512, 128)) # t1765: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1765 = prims.reshape(t1762, (1, 32, 512, 128)) # t1765: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1762\n",
+ " t1766 = torch.reshape(t1763, (1, 32, 512, 128)) # t1766: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1766 = ltorch.reshape(t1763, (1, 32, 512, 128)) # t1766: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1766 = prims.reshape(t1763, (1, 32, 512, 128)) # t1766: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1763\n",
+ " t1767 = torch_slice_prim_impl(t1764, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1767: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t1782 = torch_slice_prim_impl(t1765, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1782: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " t1797 = torch_slice_prim_impl(t1764, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1797: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t1764\n",
+ " t1799 = torch_slice_prim_impl(t1765, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1799: \"cuda:0 bf16[1, 32, 512, 0]\"\n",
+ " del t1765\n",
+ " t1768 = torch_slice_prim_impl(t1767, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1768: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1769 = torch_slice_prim_impl(t1767, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1769: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1783 = torch_slice_prim_impl(t1782, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1783: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " t1784 = torch_slice_prim_impl(t1782, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1784: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " [t1772, t1787] = nvFusion76(t1767, t1769, t1782, t1784)\n",
+ " # t1770 = prims.convert_element_type(t1769, dtypes.float32) # t1770: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1771 = prims.neg(t1770) # t1771: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1772 = prims.convert_element_type(t1771, dtypes.bfloat16) # t1772: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " # t1785 = prims.convert_element_type(t1784, dtypes.float32) # t1785: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1786 = prims.neg(t1785) # t1786: \"cuda:0 f32[1, 32, 512, 64]\"\n",
+ " # t1787 = prims.convert_element_type(t1786, dtypes.bfloat16) # t1787: \"cuda:0 bf16[1, 32, 512, 64]\"\n",
+ " del t1769, t1784\n",
+ " t1788 = torch.cat((t1787, t1783), -1) # t1788: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1788 = ltorch.cat((t1787, t1783), -1) # t1788: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1788 = prims.cat((t1787, t1783), -1) # t1788: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1787, t1783\n",
+ " t1773 = torch.cat((t1772, t1768), -1) # t1773: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1773 = ltorch.cat((t1772, t1768), -1) # t1773: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1773 = prims.cat((t1772, t1768), -1) # t1773: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1772, t1768\n",
+ " [t1781, t1796] = nvFusion77(t154, t157, t1767, t1773, t1782, t1788)\n",
+ " # t1775 = prims.convert_element_type(t1767, dtypes.float32) # t1775: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1790 = prims.convert_element_type(t1782, dtypes.float32) # t1790: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1791 = prims.mul(t1790, t154) # t1791: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1793 = prims.convert_element_type(t1788, dtypes.float32) # t1793: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1794 = prims.mul(t1793, t157) # t1794: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1795 = prims.add(t1791, t1794) # t1795: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1796 = prims.convert_element_type(t1795, dtypes.bfloat16) # t1796: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1776 = prims.mul(t1775, t154) # t1776: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1778 = prims.convert_element_type(t1773, dtypes.float32) # t1778: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1779 = prims.mul(t1778, t157) # t1779: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1780 = prims.add(t1776, t1779) # t1780: \"cuda:0 f32[1, 32, 512, 128]\"\n",
+ " # t1781 = prims.convert_element_type(t1780, dtypes.bfloat16) # t1781: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1767, t1773, t1782, t1788\n",
+ " t1800 = torch.cat((t1796, t1799), -1) # t1800: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1800 = ltorch.cat((t1796, t1799), -1) # t1800: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1800 = prims.cat((t1796, t1799), -1) # t1800: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1796, t1799\n",
+ " t1798 = torch.cat((t1781, t1797), -1) # t1798: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1798 = ltorch.cat((t1781, t1797), -1) # t1798: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " # t1798 = prims.cat((t1781, t1797), -1) # t1798: \"cuda:0 bf16[1, 32, 512, 128]\"\n",
+ " del t1781, t1797\n",
+ " (t1801, t1802, t1803, t1804, _, _, t1805, t1806, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1798, t1800, t1766, 0.0, True, scale=0.08838834764831843)\n",
+ " t1808 = torch.permute(t1801, (0, 2, 1, 3)) # t1808: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t1808 = ltorch.permute(t1801, (0, 2, 1, 3)) # t1808: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " # t1808 = prims.transpose(t1801, (0, 2, 1, 3)) # t1808: \"cuda:0 bf16[1, 512, 32, 128]\"\n",
+ " t1809 = torch.reshape(t1808, (1, 512, 4096)) # t1809: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1809 = ltorch.reshape(t1808, (1, 512, 4096)) # t1809: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1809 = prims.reshape(t1808, (1, 512, 4096)) # t1809: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " del t1808\n",
+ " t1810 = torch.nn.functional.linear(t1809, t115, None) # t1810: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1810 = ltorch.linear(t1809, t115, None) # t1810: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1810 = prims.linear(t1809, t115, None) # t1810: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t1814, t1821, t1829] = nvFusion78(t1742, t1810, t1825)\n",
+ " # t1812 = prims.convert_element_type(t1742, dtypes.float32) # t1812: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1811 = prims.convert_element_type(t1810, dtypes.float32) # t1811: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1813 = prims.add(t1811, t1812) # t1813: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1814 = prims.convert_element_type(t1813, dtypes.bfloat16) # t1814: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1816 = prims.mul(t1813, t1813) # t1816: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1817 = prims.sum(t1816, (2,)) # t1817: \"cuda:0 f32[1, 512]\"\n",
+ " # t1818 = prims.broadcast_in_dim(t1817, [1, 512, 1], [0, 1]) # t1818: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1819 = prims.div(t1818, 4096.0) # t1819: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1820 = prims.add(t1819, 1e-05) # t1820: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1821 = prims.rsqrt(t1820) # t1821: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1822 = prims.broadcast_in_dim(t1821, (1, 512, 4096), (0, 1, 2)) # t1822: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1823 = prims.mul(t1813, t1822) # t1823: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1827 = prims.convert_element_type(t1825, dtypes.float32) # t1827: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1828 = prims.mul(t1823, t1827) # t1828: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1829 = prims.convert_element_type(t1828, dtypes.bfloat16) # t1829: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t1831 = torch.nn.functional.linear(t1829, t50, None) # t1831: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1831 = ltorch.linear(t1829, t50, None) # t1831: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1831 = prims.linear(t1829, t50, None) # t1831: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t1830 = torch.nn.functional.linear(t1829, t34, None) # t1830: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1830 = ltorch.linear(t1829, t34, None) # t1830: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " # t1830 = prims.linear(t1829, t34, None) # t1830: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " [t1845] = nvFusion79(t1830, t1831)\n",
+ " # t1832 = prims.convert_element_type(t1830, dtypes.float32) # t1832: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1833 = prims.neg(t1832) # t1833: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1834 = prims.exp(t1833) # t1834: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1835 = prims.add(1.0, t1834) # t1835: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1836 = prims.reciprocal(t1835) # t1836: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1840 = prims.mul(t1832, t1836) # t1840: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1843 = prims.convert_element_type(t1831, dtypes.float32) # t1843: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1844 = prims.mul(t1840, t1843) # t1844: \"cuda:0 f32[1, 512, 11008]\"\n",
+ " # t1845 = prims.convert_element_type(t1844, dtypes.bfloat16) # t1845: \"cuda:0 bf16[1, 512, 11008]\"\n",
+ " t1846 = torch.nn.functional.linear(t1845, t116, None) # t1846: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1846 = ltorch.linear(t1845, t116, None) # t1846: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " # t1846 = prims.linear(t1845, t116, None) # t1846: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " [t1857, t1865] = nvFusion80(t1814, t1846, t1861)\n",
+ " # t1848 = prims.convert_element_type(t1814, dtypes.float32) # t1848: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1847 = prims.convert_element_type(t1846, dtypes.float32) # t1847: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1849 = prims.add(t1847, t1848) # t1849: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1852 = prims.mul(t1849, t1849) # t1852: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1853 = prims.sum(t1852, (2,)) # t1853: \"cuda:0 f32[1, 512]\"\n",
+ " # t1854 = prims.broadcast_in_dim(t1853, [1, 512, 1], [0, 1]) # t1854: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1855 = prims.div(t1854, 4096.0) # t1855: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1856 = prims.add(t1855, 1e-05) # t1856: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1857 = prims.rsqrt(t1856) # t1857: \"cuda:0 f32[1, 512, 1]\"\n",
+ " # t1858 = prims.broadcast_in_dim(t1857, (1, 512, 4096), (0, 1, 2)) # t1858: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1859 = prims.mul(t1849, t1858) # t1859: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1863 = prims.convert_element_type(t1861, dtypes.float32) # t1863: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1864 = prims.mul(t1859, t1863) # t1864: \"cuda:0 f32[1, 512, 4096]\"\n",
+ " # t1865 = prims.convert_element_type(t1864, dtypes.bfloat16) # t1865: \"cuda:0 bf16[1, 512, 4096]\"\n",
+ " t1866 = torch.nn.functional.linear(t1865, t51, None) # t1866: \"cuda:0 bf16[1, 512, 32000]\"\n",
+ " # t1866 = ltorch.linear(t1865, t51, None) # t1866: \"cuda:0 bf16[1, 512, 32000]\"\n",
+ " # t1866 = prims.linear(t1865, t51, None) # t1866: \"cuda:0 bf16[1, 512, 32000]\"\n",
+ " return {'output': t1866, 'flat_args': [t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19, t20, t21, t22, t23, t24, t25, t26, t27, t28, t29, t30, t31, t32, t33, t34, t35, t36, t37, t38, t39, t40, t41, t42, t43, t44, t45, t46, t47, t48, t49, t50, t51, t52, t53, t54, t55, t56, t57, t58, t59, t60, t61, t62, t63, t64, t65, t66, t67, t68, t69, t70, t71, t72, t73, t74, t75, t76, t77, t78, t79, t80, t81, t82, t83, t84, t85, t86, t87, t88, t89, t90, t91, t92, t93, t94, t95, t96, t97, t98, t99, t100, t101, t102, t103, t104, t105, t106, t107, t108, t109, t110, t111, t112, t113, t114, t115, t116, t117], 'flat_output': (t1866,)}, ((t0, t10, t100, t1001, t101, t1010, t102, t103, t104, t1042, t1044, t1045, t1046, t1047, t1048, t1049, t105, t1050, t1053, t1054, t1058, t106, t1065, t1069, t107, t1073, t1074, t1075, t108, t1089, t109, t1090, t1094, t11, t110, t1101, t1105, t1109, t111, t1118, t112, t113, t114, t115, t1150, t1152, t1153, t1154, t1155, t1156, t1157, t1158, t116, t1161, t1162, t1166, t1173, t1177, t1181, t1182, t1183, t1197, t1198, t12, t1202, t1209, t1213, t1217, t122, t1226, t1258, t1260, t1261, t1262, t1263, t1264, t1265, t1266, t1269, t1270, t1274, t1281, t1285, t1289, t129, t1290, t1291, t13, t1305, t1306, t1310, t1317, t1321, t1325, t133, t1334, t1366, t1368, t1369, t137, t1370, t1371, t1372, t1373, t1374, t1377, t1378, t1382, t1389, t1393, t1397, t1398, t1399, t14, t1413, t1414, t1418, t1425, t1429, t1433, t1442, t146, t1474, t1476, t1477, t1478, t1479, t1480, t1481, t1482, t1485, t1486, t1490, t1497, t15, t1501, t1505, t1506, t1507, t1521, t1522, t1526, t1533, t1537, t154, t1541, t1550, t157, t1582, t1584, t1585, t1586, t1587, t1588, t1589, t1590, t1593, t1594, t1598, t16, t1605, t1609, t1613, t1614, t1615, t1629, t1630, t1634, t1641, t1645, t1649, t1658, t1690, t1692, t1693, t1694, t1695, t1696, t1697, t1698, t17, t1701, t1702, t1706, t1713, t1717, t1721, t1722, t1723, t1737, t1738, t1742, t1749, t1753, t1757, t1766, t178, t1798, t18, t180, t1800, t1801, t1802, t1803, t1804, t1805, t1806, t1809, t181, t1810, t1814, t182, t1821, t1825, t1829, t183, t1830, t1831, t184, t1845, t1846, t185, t1857, t186, t1861, t1865, t189, t19, t190, t194, t20, t201, t205, t209, t21, t210, t211, t22, t225, t226, t23, t230, t237, t24, t241, t245, t25, t254, t26, t27, t28, t286, t288, t289, t29, t290, t291, t292, t293, t294, t297, t298, t3, t30, t302, t309, t31, t313, t317, t318, t319, t32, t33, t333, t334, t338, t34, t345, t349, t35, t353, t36, t362, t37, t38, t39, t394, t396, t397, t398, t399, t4, t40, t400, t401, t402, t405, t406, t41, t410, t417, t42, t421, t425, t426, t427, t43, t44, t441, t442, t446, t45, t453, t457, t46, t461, t47, t470, t48, t49, t5, t50, t502, t504, t505, t506, t507, t508, t509, t51, t510, t513, t514, t518, t525, t529, t533, t534, t535, t549, t550, t554, t561, t565, t569, t578, t6, t610, t612, t613, t614, t615, t616, t617, t618, t621, t622, t626, t633, t637, t641, t642, t643, t657, t658, t662, t669, t673, t677, t686, t7, t718, t720, t721, t722, t723, t724, t725, t726, t729, t730, t734, t741, t745, t749, t750, t751, t765, t766, t770, t777, t781, t785, t794, t8, t826, t828, t829, t830, t831, t832, t833, t834, t837, t838, t842, t849, t85, t853, t857, t858, t859, t86, t87, t873, t874, t878, t88, t885, t889, t89, t893, t9, t90, t902, t91, t92, t93, t934, t936, t937, t938, t939, t94, t940, t941, t942, t945, t946, t95, t950, t957, t96, t961, t965, t966, t967, t97, t98, t981, t982, t986, t99, t993, t997), (False, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 0.0, 4096.0, 4096.0, 0.08838834764831843, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 32000, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2))"
]
},
"execution_count": 9,
@@ -1014,35 +3161,40 @@
}
],
"source": [
+ "print(actual.grad_fn)\n",
"thunder.last_traces(thunder_model)[-1]"
]
},
{
"cell_type": "markdown",
- "id": "4944f352",
- "metadata": {},
+ "id": "558f2553-37f7-4b58-b7cd-a744155613a8",
+ "metadata": {
+ "slideshow": {
+ "slide_type": "notes"
+ }
+ },
"source": [
"Well, that is quite a bit to look through.\n",
- "But here is a key thing: The function now returns a buch of things. This is because Thunder applies the same treatment to the backward and to this end saves information from the forward. You can see a hint of this because the output has a `ThunderFunctionBackward` on as its `grad_fn`. (You can see the backward trace with \n",
+ "But here is a key thing: The function now returns a bunch of things. This is because Thunder applies the same treatment to the backward and to this end saves information from the forward. You can see a hint of this because the output has a `ThunderFunctionBackward` on as its `grad_fn`. (You can see the backward trace with \n",
"`thunder.last_backward_traces(thunder_model)[-1]`)."
]
},
{
"cell_type": "code",
"execution_count": 10,
- "id": "4d90df65",
+ "id": "59643398-d6e2-4c32-81bd-145a1198b1f3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "tensor([[[-0.9922, 0.5946, -0.2173, ..., -0.0981, -0.5058, 0.2747],\n",
- " [-1.1552, 0.5770, -0.7432, ..., 0.0688, 0.1238, 0.6786],\n",
- " [-0.7813, 0.6960, 0.1235, ..., -0.4840, 0.1373, 0.6490],\n",
+ "tensor([[[ 0.4160, -0.4668, 1.1016, ..., 0.5430, 1.2656, 0.2891],\n",
+ " [ 0.3320, -0.0557, 1.7891, ..., 1.0703, 1.0078, 1.2266],\n",
+ " [ 0.6836, -0.2871, 0.9531, ..., 0.0806, 0.7070, 0.8477],\n",
" ...,\n",
- " [ 0.3711, 0.1656, 0.3350, ..., -0.0294, 0.3670, 0.5099],\n",
- " [-0.2544, -0.8470, 0.2063, ..., -0.1341, 0.1877, 0.2612],\n",
- " [ 0.3420, -1.1421, 0.9222, ..., 0.5636, 0.1666, 0.6947]]],\n",
+ " [ 0.7695, -0.1260, 0.7266, ..., 0.1118, -0.0238, -1.2656],\n",
+ " [-0.7773, -0.5547, -0.3047, ..., -0.1807, 0.1895, 0.6875],\n",
+ " [ 0.8867, 0.4766, 0.3984, ..., 0.0815, -0.0879, 0.3477]]],\n",
" device='cuda:0', grad_fn=)"
]
},
@@ -1057,10 +3209,10 @@
},
{
"cell_type": "markdown",
- "id": "7dcec40f",
+ "id": "17341d86-d4c9-46bd-ac5e-3a05da1ff72c",
"metadata": {},
"source": [
- "One thing to keep in mind here is that for bf16, the numerical accuracy impact of rearranging operations can be quite pronounced."
+ "Let us clean up a bit."
]
},
{
@@ -1068,25 +3220,21 @@
"execution_count": 11,
"id": "6ba7f715",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "maximum deviation grads: 0.00042724609375\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "actual_grads = torch.autograd.grad(actual.sum(), m.parameters())\n",
- "expected_grads = torch.autograd.grad(expected.sum(), m.parameters())\n",
- "print(\"maximum deviation grads:\", max((a-e).abs().max().item() for a, e in zip(actual_grads, expected_grads)))"
+ "del actual, expected\n",
+ "import gc\n",
+ "gc.collect();"
]
},
{
"cell_type": "markdown",
"id": "0261eb11",
- "metadata": {},
+ "metadata": {
+ "slideshow": {
+ "slide_type": "slide"
+ }
+ },
"source": [
"But is it faster? Yes!"
]
@@ -1094,50 +3242,52 @@
{
"cell_type": "code",
"execution_count": 12,
- "id": "854f29a5",
+ "id": "bccec79b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "154 ms ± 281 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
- "150 ms ± 342 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
+ "240 ms ± 105 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
+ "208 ms ± 147 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
- "import gc\n",
- "gc.collect()\n",
"%timeit r = m(inp); torch.autograd.grad(r.sum(), m.parameters()); torch.cuda.synchronize()\n",
"%timeit r = thunder_model(inp); torch.autograd.grad(r.sum(), m.parameters()); torch.cuda.synchronize()"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "1d31e7f8",
+ "metadata": {},
+ "source": [
+ "So far, so good! Thunder should work with LitGPT today and we busy are adding the support required to run other models as well!\n"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 13,
- "id": "eb177aad",
+ "id": "ecad9125-bbf2-42c8-b11c-23eed4a6cd8f",
"metadata": {},
"outputs": [],
"source": [
"del m, thunder_model\n",
"import gc\n",
"gc.collect()\n",
- "torch.cuda.empty_cache()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "1d31e7f8",
- "metadata": {},
- "source": [
- "So far, so good! Thunder should work with LitGPT today and we busy are adding the support required to run other models as well!"
+ "torch.cuda.empty_cache()\n"
]
},
{
"cell_type": "markdown",
- "id": "d23ebbf5",
- "metadata": {},
+ "id": "49e3273c-99be-4370-9e59-121c00481b4e",
+ "metadata": {
+ "slideshow": {
+ "slide_type": "slide"
+ }
+ },
"source": [
"## Distributed with Thunder\n",
"\n",
@@ -1160,7 +3310,11 @@
"cell_type": "code",
"execution_count": 14,
"id": "18dd3379",
- "metadata": {},
+ "metadata": {
+ "slideshow": {
+ "slide_type": "slide"
+ }
+ },
"outputs": [
{
"name": "stdout",
@@ -1172,21 +3326,19 @@
],
"source": [
"%%writefile zero_to_thunder_fsdp_simple_example.py\n",
- "import sys\n",
- "sys.path.insert(0, '..')\n",
"from thunder.tests.lit_gpt_model import GPT, Config\n",
- "\n",
- "import torch\n",
- "import torch.distributed\n",
- "import thunder\n",
- "import thunder.distributed\n",
"import os\n",
+ "import torch, torch.distributed\n",
+ "import thunder, thunder.distributed\n",
"\n",
"# Create Model\n",
"# NOTE: We create the model on CPU.\n",
"device='cpu'\n",
"torch.set_default_dtype(torch.bfloat16)\n",
- "model = GPT.from_name('llama2-like')\n",
+ "cfg = Config.from_name('Llama-2-7b-hf')\n",
+ "cfg.n_layer = 8 # fewer layers\n",
+ "model = GPT(cfg)\n",
+ "\n",
"# Setup for distributed\n",
"torch.distributed.init_process_group(backend='nccl')\n",
"rank = int(os.environ[\"LOCAL_RANK\"])\n",
@@ -1197,13 +3349,19 @@
"# thunder.distributed.fsdp takes care of moving the parameter\n",
"# shard to the correct GPU for the current process.\n",
"model = thunder.jit(thunder.distributed.fsdp(model)) # <---------------------------------------\n",
- "\n",
+ "print(f\"rank {rank} computing\")\n",
"# Run the forward pass.\n",
- "res = model(x)\n",
- "res.sum().backward()\n",
- "\n",
- "res = model(x)\n",
- "res.sum().backward()\n"
+ "for i in range(10):\n",
+ " res = model(x)\n",
+ " res.sum().backward()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "97e8edbf-424d-49a7-8ed6-12cb5e5d65fc",
+ "metadata": {},
+ "source": [
+ "Now we can launch it. Note that you need two GPUs for this to run correctly."
]
},
{
@@ -1211,17 +3369,22 @@
"execution_count": 15,
"id": "2bad9b64",
"metadata": {
- "scrolled": false
+ "scrolled": true,
+ "slideshow": {
+ "slide_type": "skip"
+ }
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "W0316 11:53:02.156000 140513675427904 torch/distributed/run.py:757] \r\n",
- "W0316 11:53:02.156000 140513675427904 torch/distributed/run.py:757] *****************************************\r\n",
- "W0316 11:53:02.156000 140513675427904 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \r\n",
- "W0316 11:53:02.156000 140513675427904 torch/distributed/run.py:757] *****************************************\r\n"
+ "W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] \n",
+ "W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] *****************************************\n",
+ "W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \n",
+ "W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] *****************************************\n",
+ "rank 1 computing\n",
+ "rank 0 computing\n"
]
}
],
@@ -1232,21 +3395,29 @@
{
"cell_type": "markdown",
"id": "9c65e75d",
- "metadata": {},
+ "metadata": {
+ "slideshow": {
+ "slide_type": "skip"
+ }
+ },
"source": [
- "So there. FSDP with just wrapping the model in `fsdp`."
+ "So there. FSDP with just wrapping the model in `fsdp`.\n"
]
},
{
"cell_type": "markdown",
"id": "4a6d7a20",
- "metadata": {},
+ "metadata": {
+ "slideshow": {
+ "slide_type": "slide"
+ }
+ },
"source": [
"## Extending Thunder\n",
"\n",
"But we promised that thunder is extensible. Let's find out what's up with that.\n",
"\n",
- "Specifically, we will incorporate the RMSNorm kernel from the great [Unsloth project](https://github.com/unslothai/unsloth/) into our model (note that NVFuser also creates a fused kernel for this).\n",
+ "Specifically, we will incorporate the fast rope embedding kernel from the great [Unsloth project](https://github.com/unslothai/unsloth/) into our model (note that NVFuser also creates a fused kernel for this).\n",
"\n",
"In Thunder, extensions (as well as most builtin optimizations which use the exact same mechanism) work with _executors_ handling operations. Let us define one."
]
@@ -1275,91 +3446,94 @@
},
{
"cell_type": "markdown",
- "id": "a63595ab",
- "metadata": {},
+ "id": "2fe3b40b-c6e9-417c-ab7a-32606cee871a",
+ "metadata": {
+ "slideshow": {
+ "slide_type": "skip"
+ }
+ },
"source": [
- "For our base implementation, we take the ccode from [LitGPT's RMSNorm implementation](https://github.com/Lightning-AI/litgpt/blob/7c1574925f973e64c0a53e056b77229bedee1619/lit_gpt/rmsnorm.py)\n",
+ "For our base implementation, we take the code from [LitGPT's implementation](https://github.com/Lightning-AI/litgpt/blob/be6139e1fd4b240d253efd58124457496d23d173/litgpt/model.py#L355-L361)\n",
"\n",
- "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function.\n"
+ "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function.\n",
+ "Because we will demonstrate Thunder's ability to divert functions in the model, we make a version here that will not be diverted."
]
},
{
"cell_type": "code",
"execution_count": 17,
- "id": "247074b3",
- "metadata": {},
+ "id": "3e74436b-d8eb-472b-9d6d-b6412378fde7",
+ "metadata": {
+ "slideshow": {
+ "slide_type": "skip"
+ }
+ },
"outputs": [],
"source": [
- "from thunder import TensorProxy\n",
- "\n",
- "# Taken from LitGPT, who in turn credit:\n",
- "# Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:\n",
- "# https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.\n",
- "\n",
- "def rms_norm_impl(x: torch.Tensor, weight, dim: int, eps: float, add_unit_offset: bool) -> torch.Tensor:\n",
- " dtype = x.dtype\n",
- " x = x.float()\n",
- " # NOTE: the original RMSNorm paper implementation is not equivalent\n",
- " norm_x = torch.mean(x * x, dim=dim, keepdim=True)\n",
- " x_normed = x * torch.rsqrt(norm_x + eps)\n",
- " x_normed = x_normed.to(dtype=dtype)\n",
- " if add_unit_offset:\n",
- " # Gemma model requires a unit offset\n",
- " # https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L176\n",
- " return x_normed * (1 + weight)\n",
- " return x_normed * weight\n",
- "\n",
- "def rms_norm_meta(x: TensorProxy, weight, dim: int, eps: float, add_unit_offset: bool) -> TensorProxy:\n",
- " return TensorProxy(like=x)\n",
- "\n",
- "rms_norm = my_ex.register_operator('rms_norm', meta=rms_norm_meta, fn=rms_norm_impl)\n"
+ "import lit_gpt\n",
+ "def apply_rope_copy(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n",
+ " head_size = x.size(-1)\n",
+ " x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)\n",
+ " x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)\n",
+ " rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)\n",
+ " roped = (x * cos) + (rotated * sin)\n",
+ " return roped.to(dtype=x.dtype)"
]
},
{
"cell_type": "markdown",
- "id": "75ad1dbf",
- "metadata": {},
+ "id": "a63595ab",
+ "metadata": {
+ "slideshow": {
+ "slide_type": "skip"
+ }
+ },
"source": [
- "Because evil monkey-patching is a thing for short demos is a thing, let's replace LitGPT's own implementation. For your own model, you might start out with a that in your code directly."
+ "### Registering operators\n",
+ "\n",
+ "Say we have a function `apply_rope` applying the RoPE transformation in PyTorch.\n",
+ "\n",
+ "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function and tell it to use the new symbol instead of the original function `lit_gpt.model.apply_rope`.\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
- "id": "e0bdecd3",
+ "id": "247074b3",
"metadata": {},
"outputs": [],
"source": [
- "import lit_gpt.rmsnorm\n",
- "if not hasattr(lit_gpt.rmsnorm, 'ThunderOrigRMSNorm'):\n",
- " lit_gpt.rmsnorm.ThunderOrigRMSNorm = lit_gpt.rmsnorm.RMSNorm\n",
+ "import torch, thunder\n",
+ "from thunder.tests.lit_gpt_model import GPT\n",
+ "from thunder import TensorProxy\n",
"\n",
- "class ThunderizedRMSNorm(lit_gpt.rmsnorm.ThunderOrigRMSNorm):\n",
- " def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
- " # This isn't the best paradigm. :/\n",
- " if thunder.core.interpreter.is_jitting():\n",
- " return rms_norm(x, self.weight, self.dim, self.eps, self.add_unit_offset)\n",
- " else:\n",
- " return super().forward(x)\n",
+ "def apply_rope_impl(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n",
+ " return lit_gpt.model.apply_rope(x, cos, sin)\n",
+ "\n",
+ "def apply_rope_meta(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:\n",
+ " return TensorProxy(like=x)\n",
"\n",
- "lit_gpt.rmsnorm.RMSNorm = ThunderizedRMSNorm"
+ "apply_rope = my_ex.register_operator('apply_rope', like=apply_rope_meta, fn=apply_rope_impl,\n",
+ " replaces=lit_gpt.model.apply_rope)"
]
},
{
"cell_type": "markdown",
"id": "d6b7d056",
- "metadata": {},
+ "metadata": {
+ "slideshow": {
+ "slide_type": "slide"
+ }
+ },
"source": [
- "We can try our new RMSNorm: "
+ "### Testing our new operator "
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "0ebd5dd1",
- "metadata": {
- "scrolled": false
- },
+ "metadata": {},
"outputs": [
{
"name": "stdout",
@@ -1377,12 +3551,13 @@
"\n",
"@torch.no_grad()\n",
"@no_autocast()\n",
- "def computation(x, t_weight):\n",
- " # x: \"cuda:0 f32[256, 4096]\" \n",
- " # t_weight: \"cuda:0 f32[4096]\" \n",
- " t7 = rms_norm(x, t_weight, -1, 1e-06, False) # t7: \"cuda:0 f32[256, 4096]\"\n",
- " del x, t_weight\n",
- " return t7"
+ "def computation(x, t_1_cos, t_1_sin):\n",
+ " # x: \"cuda:0 bf16[2, 128, 4096, 16]\" \n",
+ " # t_1_cos: \"cuda:0 f32[4096, 16]\" \n",
+ " # t_1_sin: \"cuda:0 f32[4096, 16]\" \n",
+ " t2 = apply_rope(x, t_1_cos, t_1_sin) # t2: \"cuda:0 bf16[2, 128, 4096, 16]\"\n",
+ " del x, t_1_cos, t_1_sin\n",
+ " return t2"
]
},
"execution_count": 19,
@@ -1391,37 +3566,37 @@
}
],
"source": [
- "with torch.device('cuda'):\n",
- " norm_module = ThunderizedRMSNorm(4096)\n",
- " x = torch.randn(256, 4096)\n",
+ "with torch.device('cuda'): m = GPT.from_name('llama2-like'); Q = torch.randn(2, 128, 4096, 16)\n",
"\n",
- "# we're not quite there to handle forward and backward yet, we'll re-enable them below\n",
- "for p in norm_module.parameters(): \n",
- " p.requires_grad_(False)\n",
+ "def test_apply_rope(x, m):\n",
+ " return lit_gpt.model.apply_rope(x, m.cos, m.sin)\n",
"\n",
- "thunder_norm_module = thunder.jit(norm_module, executors=(my_ex,) + thunder.get_default_executors()) \n",
+ "thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors()) \n",
"\n",
- "expected = norm_module(x)\n",
- "actual = thunder_norm_module(x)\n",
+ "expected = test_apply_rope(Q, m); actual = thunder_apply_rope(Q, m); print(\"deviation:\", (expected - actual).abs().max().item())\n",
"\n",
- "print(\"deviation:\", (expected - actual).abs().max().item())\n",
- "\n",
- "thunder.last_traces(thunder_norm_module)[-1]"
+ "thunder.last_traces(thunder_apply_rope)[-1]"
]
},
{
"cell_type": "markdown",
"id": "8c620a38",
- "metadata": {},
+ "metadata": {
+ "slideshow": {
+ "slide_type": "slide"
+ }
+ },
"source": [
+ "### Optimized kernels\n",
+ "\n",
"But why did we do this? Well, we can now layer a faster implementation on top.\n",
- "For this we take the [unsloth RMSNorm](https://github.com/unslothai/unsloth/blob/42076f6580e71522ed1c122043edfba595be64e4/unsloth/kernels/rms_layernorm.py) kernels. We the bits that were in the forward and backward of the `autograd.Function` into our implementation functions and define the corresponding metas."
+ "For this we take the [unsloth fast rope embedding](https://github.com/unslothai/unsloth/blob/42076f6580e71522ed1c122043edfba595be64e4/unsloth/kernels/rope_embedding.py) kernels. We take the bits that were in the forward and backward of the `autograd.Function` into our implementation functions. Note that we include the transpositions in our setup in order to have compatibility to the LitGPT implementation. This change in memory layout of the operands can have a large effect on the runtime though, so our timings are likely not representative of the ones the Unsloth project gets in their use of the same triton kernels."
]
},
{
"cell_type": "code",
"execution_count": 20,
- "id": "a7a26f5f",
+ "id": "6e6d0b1e-ba14-43e5-b0d9-27c0e3b46879",
"metadata": {},
"outputs": [],
"source": [
@@ -1457,196 +3632,214 @@
" elif BLOCK_SIZE >= 2048: num_warps = 8\n",
" return BLOCK_SIZE, num_warps\n",
"\n",
+ "@triton.heuristics({\"BACKWARD_PASS\": lambda args: args[\"BACKWARD_PASS\"],})\n",
"@triton.jit\n",
- "def _rms_layernorm_forward(\n",
- " Y, Y_row_stride,\n",
- " X, X_row_stride,\n",
- " W, W_row_stride,\n",
- " r, r_row_stride,\n",
- " n_cols, eps,\n",
- " BLOCK_SIZE : tl.constexpr\n",
+ "def _rope_embedding(\n",
+ " Q, Q_row_stride,\n",
+ " cos, cos_row_stride,\n",
+ " sin, sin_row_stride,\n",
+ " seqlen, head_dim, group_size, n_heads,\n",
+ " BACKWARD_PASS: tl.constexpr,\n",
+ " BLOCK_SIZE : tl.constexpr,\n",
"):\n",
" \"\"\"\n",
- " Fast RMS Layernorm kernel\n",
- " Inspiration from a Triton tutorial:\n",
- " https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n",
+ " Calculates the RoPE Embedding quickly\n",
+ " RoPE is Q * cos + rotate_half(Q) * sin\n",
+ " See our blog post for more info\n",
" \"\"\"\n",
- " row_idx = tl.program_id(0)\n",
- " col_offsets = tl.arange(0, BLOCK_SIZE)\n",
- " mask = col_offsets < n_cols\n",
+ " row_position = tl.program_id(0)\n",
+ " group_head_position = tl.program_id(1)\n",
+ " col_offsets = tl.arange(0, BLOCK_SIZE)\n",
+ " half_head_dim = head_dim // 2\n",
+ " mask = col_offsets < half_head_dim\n",
"\n",
- " Y += row_idx * Y_row_stride\n",
- " X += row_idx * X_row_stride\n",
- " r += row_idx * r_row_stride\n",
+ " sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \\\n",
+ " half_head_dim*0 + col_offsets, mask = mask, other = 0)\n",
+ " cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \\\n",
+ " half_head_dim*0 + col_offsets, mask = mask, other = 0)\n",
"\n",
- " X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)\n",
- " W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)\n",
+ " if BACKWARD_PASS:\n",
+ " # See our blog post for more info.\n",
+ " sin1 = -sin1\n",
+ " pass\n",
"\n",
- " row_var = tl.sum(X_row * X_row, axis = 0) / n_cols\n",
- " inv_var = tl.math.rsqrt(row_var + eps)\n",
- " tl.store(r, inv_var)\n",
- " normed = X_row * inv_var\n",
- " normed = normed.to(W_row.dtype) # Exact copy from HF\n",
- " output = normed * W_row\n",
- " tl.store(Y + col_offsets, output, mask = mask)\n",
+ " head_start = group_head_position * group_size\n",
+ " head_end = min((head_start + group_size), n_heads)\n",
"\n",
+ " for i in range(head_start, head_end):\n",
+ " offs_q1 = row_position * Q_row_stride + i * head_dim + col_offsets\n",
+ " offs_q2 = row_position * Q_row_stride + i * head_dim + col_offsets + half_head_dim\n",
"\n",
- "@triton.jit\n",
- "def _rms_layernorm_backward(\n",
- " dY, dY_row_stride,\n",
- " X, X_row_stride,\n",
- " W, W_row_stride,\n",
- " r, r_row_stride,\n",
- " dW, dW_row_stride,\n",
- " n_cols, eps,\n",
- " BLOCK_SIZE : tl.constexpr,\n",
- "):\n",
- " \"\"\"\n",
- " Fast RMS Layernorm kernel for the backward pass\n",
- " Inspiration from a Triton tutorial:\n",
- " https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n",
- " \"\"\"\n",
- " row_idx = tl.program_id(0)\n",
- " col_offsets = tl.arange(0, BLOCK_SIZE)\n",
- " mask = col_offsets < n_cols\n",
- "\n",
- " dY += row_idx * dY_row_stride\n",
- " X += row_idx * X_row_stride\n",
- " r += row_idx * r_row_stride\n",
- "\n",
- " dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)\n",
- " X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)\n",
- " W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)\n",
- "\n",
- " # Get saved row variance\n",
- " inv_var = tl.load(r).to(tl.float32)\n",
- " normed = X_row * inv_var\n",
- "\n",
- " dY_W = dY_row * W_row\n",
- "\n",
- " rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)\n",
- " output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)\n",
- " tl.store(dY + col_offsets, output, mask = mask)\n",
- " \n",
- "def rms_layernorm_forward_impl(X, W, eps):\n",
- " shape = X.shape\n",
- " dim = shape[-1]\n",
- " X = X.view(-1, dim)\n",
- " n_rows, n_cols = X.shape\n",
- " BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n",
- "\n",
- " Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = \"cuda\")\n",
- " r = torch.empty(n_rows, dtype = torch.float32, device = \"cuda\")\n",
- "\n",
- " _rms_layernorm_forward[(n_rows,)](\n",
- " Y, Y.stride(0),\n",
- " X, X.stride(0),\n",
- " W, W.stride(0),\n",
- " r, r.stride(0),\n",
- " n_cols, eps,\n",
+ " # For Gemma - sometimes RoPE must be done in float32 and not bfloat16\n",
+ " Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)\n",
+ " Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)\n",
+ "\n",
+ " tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)\n",
+ " tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)\n",
+ " pass\n",
+ "pass\n",
+ "\n",
+ "\n",
+ "def fast_rope_embedding_forward(Q, cos, sin):\n",
+ " Q = Q.transpose(1, 2).clone()\n",
+ " cos, sin = cos.squeeze(), sin.squeeze()\n",
+ " batch, seq_len, n_heads, head_dim = Q.shape\n",
+ " Q = Q.reshape(batch*seq_len, n_heads*head_dim)\n",
+ " n_rows, n_cols = Q.shape\n",
+ " assert(seq_len <= cos.shape[0])\n",
+ "\n",
+ " # [TODO] Changing blocksize to head_dim//2 seems to have\n",
+ " # some concurrency / un-deterministic issues.\n",
+ " BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)\n",
+ " group_size = 4 # 4 or 8, too large group_size can hurt performance.\n",
+ " n_groups = triton.cdiv(n_heads, group_size)\n",
+ "\n",
+ " grid = (n_rows, n_groups, )\n",
+ " _rope_embedding[grid](\n",
+ " Q, Q.stride(0),\n",
+ " cos, cos.stride(0),\n",
+ " sin, sin.stride(0),\n",
+ " seq_len, head_dim, group_size, n_heads,\n",
+ " BACKWARD_PASS = False,\n",
" BLOCK_SIZE = BLOCK_SIZE,\n",
" num_warps = num_warps,\n",
" )\n",
- " return Y.view(*shape), (r, BLOCK_SIZE, num_warps)\n",
- "\n",
- "def rms_layernorm_forward_meta(X, W, eps):\n",
- " n_cols = X.shape[-1]\n",
- " n_rows = 1\n",
- " for i in X.shape[:-1]:\n",
- " n_rows *= i\n",
- " BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n",
- " Y = TensorProxy(like=X, requires_grad=True)\n",
- " return (Y,\n",
- " (TensorProxy(shape=(n_rows,), device=X.device, dtype=thunder.dtypes.float32, requires_grad=False),\n",
- " BLOCK_SIZE, \n",
- " num_warps,\n",
- " )\n",
- " )\n",
- "\n",
- "def rms_layernorm_backward_impl(X, W, r, eps, BLOCK_SIZE, num_warps, dY):\n",
- " shape = dY.shape\n",
- " dim = shape[-1]\n",
- " dY = dY.view(-1, dim)\n",
+ " Q = Q.view(batch, seq_len, n_heads, head_dim).transpose(1, 2)\n",
+ " return Q, (BLOCK_SIZE, num_warps) \n",
+ "\n",
+ "def fast_rope_embedding_backward(BLOCK_SIZE, num_warps, cos, sin, dY):\n",
+ " dY = dY.transpose(1, 2)\n",
+ " batch, seq_len, n_heads, head_dim = dY.shape\n",
+ " dY = dY.reshape(batch*seq_len, n_heads*head_dim)\n",
+ " # Must be reshape not view\n",
" n_rows, n_cols = dY.shape\n",
- " dW = X\n",
- " dX = dY.clone()\n",
- " _rms_layernorm_backward[(n_rows,)](\n",
- " dX, dX.stride(0),\n",
- " X, X .stride(0),\n",
- " W, W .stride(0),\n",
- " r, r .stride(0),\n",
- " dW, dW.stride(0),\n",
- " n_cols, eps,\n",
+ "\n",
+ " group_size = 4 # 4 or 8, too large group_size can hurt performance.\n",
+ " n_groups = triton.cdiv(n_heads, group_size)\n",
+ "\n",
+ " grid = (n_rows, n_groups, )\n",
+ " _rope_embedding[grid](\n",
+ " dY, dY .stride(0),\n",
+ " cos, cos.stride(0),\n",
+ " sin, sin.stride(0),\n",
+ " seq_len, head_dim, group_size, n_heads,\n",
+ " BACKWARD_PASS = True,\n",
" BLOCK_SIZE = BLOCK_SIZE,\n",
" num_warps = num_warps,\n",
" )\n",
- " dX = dX.view(*shape)\n",
- " return dX\n",
+ " dY = dY.view(batch, seq_len, n_heads, head_dim)\n",
+ " dY = dY.transpose(1, 2) \n",
+ " return dY\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ed1e9be3-d1c9-4c4b-bf14-a025a03687ac",
+ "metadata": {},
+ "source": [
+ "We also define the corresponding meta functions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "d7e6612d-f1fc-497c-9d64-15ef99824086",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def fast_rope_embedding_forward_meta(Q, cos, sin):\n",
+ " batch, n_heads, seq_len, head_dim = Q.shape\n",
+ " n_rows, n_cols = batch*seq_len, n_heads*head_dim \n",
+ " assert(seq_len <= cos.shape[0])\n",
+ "\n",
+ " BLOCK_SIZE, num_warps = calculate_settings(head_dim//2)\n",
+ " return TensorProxy(like=Q), (BLOCK_SIZE, num_warps) \n",
"\n",
- "def rms_layernorm_backward_meta(X, W, r, eps, BLOCK_SIZE, num_warps, dY):\n",
+ "def fast_rope_embedding_backward_meta(BLOCK_SIZE, num_warps, cos, sin, dY):\n",
" return TensorProxy(like=dY)"
]
},
{
"cell_type": "markdown",
"id": "b70eba5f",
- "metadata": {},
+ "metadata": {
+ "slideshow": {
+ "slide_type": "slide"
+ }
+ },
"source": [
- "With this, we can just register the additional operators:"
+ "### Register optimized operators\n",
+ "\n",
+ "Just like the `apply_rope` before, we can register operators for the optimized forward and backward."
]
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 22,
"id": "f8f1e77e",
"metadata": {},
"outputs": [],
"source": [
- "unsloth_rms_norm_forward = my_ex.register_operator('unsloth_rms_norm_forward', meta=rms_layernorm_forward_meta, fn=rms_layernorm_forward_impl)\n",
- "unsloth_rms_norm_backward = my_ex.register_operator('unsloth_rms_norm_backward', meta=rms_layernorm_backward_meta, fn=rms_layernorm_backward_impl)"
+ "unsloth_apply_rope_forward = my_ex.register_operator('unsloth_apply_rope_forward', \n",
+ " meta=fast_rope_embedding_forward_meta, fn=fast_rope_embedding_forward)\n",
+ "unsloth_apply_rope_backward = my_ex.register_operator('unsloth_apply_rope_backward', \n",
+ " meta=fast_rope_embedding_backward_meta, fn=fast_rope_embedding_backward)"
]
},
{
"cell_type": "markdown",
"id": "2426263d",
- "metadata": {},
+ "metadata": {
+ "slideshow": {
+ "slide_type": "slide"
+ }
+ },
"source": [
- "But instead of monkey-patching more, we can now register the kernel as an _implementation_ of the base `rms_norm` primitive defined above. For this we need an _execution transform_ - which is a fancy word for a function that implements the original operator (`rms_norm`) in terms of our new operator - so it has the call signature of the `rms_norm`. Because - like many fast implementations - the unsloth RMS norm does not implement the operator in full generality (to do them justice, they have a variant adding the unit offset, we just didn't copy it over), we implement a checker function, too: It takes the arguments of the operator we want specialize and returns a bool whether our implementation handles the given inputs."
+ "### Implementations for operators\n",
+ "\n",
+ "Do we need to divert `apply_rope` again? No!\n",
+ "We can register the specialized kernel as an _implementation_ of our base `apply_rope` operator. For this we need an _execution transform_ - which is a fancy word for a function that implements the original operator (`apply_ropw`) in terms of our new operator - so it has the call signature of the `apply_rope`. Because - like many fast implementations - the unsloth rope embedding does not implement the operator in full generality (well, actually they mainly want a 4d tensor input), we implement a checker function, too: It takes the arguments of the operator we want specialize and returns a bool whether our implementation handles the given inputs."
]
},
{
"cell_type": "code",
- "execution_count": 28,
+ "execution_count": 23,
"id": "6b5c8320",
"metadata": {},
"outputs": [],
"source": [
- "def rms_norm_to_unsloth(x: TensorProxy, weight: TensorProxy, dim: int, eps: float, add_unit_offset: bool):\n",
- " assert dim == -1 and not add_unit_offset\n",
- " res, _ = unsloth_rms_norm_forward(x, weight, eps)\n",
+ "def apply_rope_to_unsloth(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:\n",
+ " assert len(x.shape) == 4\n",
+ " res, *_ = unsloth_apply_rope_forward(x, cos, sin)\n",
" return res\n",
"\n",
- "def rms_norm_to_unsloth_checker(x: TensorProxy, weight: TensorProxy, dim: int, eps: float, add_unit_offset: bool):\n",
- " if dim != -1 or add_unit_offset:\n",
+ "def apply_rope_to_unsloth_checker(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> bool:\n",
+ " if len(x.shape) != 4:\n",
" return False\n",
- " if weight.requires_grad:\n",
- " return False # the unsloth rms norm backwward only gives the grad w.r.t. x\n",
- " return x.device.devicetype == thunder.devices.DeviceType.CUDA and weight.device.devicetype == thunder.devices.DeviceType.CUDA\n",
+ " return (x.device.devicetype == thunder.devices.DeviceType.CUDA and\n",
+ " cos.device.devicetype == thunder.devices.DeviceType.CUDA and\n",
+ " cos.device.devicetype == thunder.devices.DeviceType.CUDA)\n",
"\n",
- "my_ex.register_implementation(rms_norm, checker=rms_norm_to_unsloth_checker, execution_transform=rms_norm_to_unsloth)\n"
+ "my_ex.register_implementation(apply_rope,\n",
+ " checker=apply_rope_to_unsloth_checker,\n",
+ " execution_transform=apply_rope_to_unsloth)\n"
]
},
{
"cell_type": "markdown",
"id": "eec7c95a",
- "metadata": {},
+ "metadata": {
+ "slideshow": {
+ "slide_type": "slide"
+ }
+ },
"source": [
- "So let us give that a try! Works great..."
+ "So let us give it a try! Works great..."
]
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 24,
"id": "965ba1d7",
"metadata": {},
"outputs": [
@@ -1654,7 +3847,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "deviation: 9.5367431640625e-07\n"
+ "deviation: 0.015625\n"
]
},
{
@@ -1666,49 +3859,45 @@
"\n",
"@torch.no_grad()\n",
"@no_autocast()\n",
- "def computation(x, t_weight):\n",
- " # x: \"cuda:0 f32[2048, 4096]\" \n",
- " # t_weight: \"cuda:0 f32[4096]\" \n",
- " (t7, (_, _, _)) = unsloth_rms_norm_forward(x, t_weight, 1e-06)\n",
- " del x, t_weight\n",
- " return t7"
+ "def computation(x, t_1_cos, t_1_sin):\n",
+ " # x: \"cuda:0 bf16[2, 128, 4096, 16]\" \n",
+ " # t_1_cos: \"cuda:0 f32[4096, 16]\" \n",
+ " # t_1_sin: \"cuda:0 f32[4096, 16]\" \n",
+ " (t2, (_, _)) = unsloth_apply_rope_forward(x, t_1_cos, t_1_sin)\n",
+ " del x, t_1_cos, t_1_sin\n",
+ " return t2"
]
},
- "execution_count": 23,
+ "execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "with torch.device('cuda'):\n",
- " norm_module = ThunderizedRMSNorm(4096)\n",
- "\n",
- "# unfortunately, we meet dragons if we don't do this at this stage\n",
- "for p in norm_module.parameters(): \n",
- " p.requires_grad_(False)\n",
- "\n",
- "thunder_norm_module = thunder.jit(norm_module, executors=[my_ex,]) \n",
- "x = torch.randn(2048, 4096, device=\"cuda\")\n",
- "\n",
- "expected = norm_module(x)\n",
- "actual = thunder_norm_module(x)\n",
+ "thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors()) \n",
"\n",
+ "expected = test_apply_rope(Q, m)\n",
+ "actual = thunder_apply_rope(Q, m)\n",
"print(\"deviation:\", (expected - actual).abs().max().item())\n",
"\n",
- "thunder.last_traces(thunder_norm_module)[-1]"
+ "thunder.last_traces(thunder_apply_rope)[-1]"
]
},
{
"cell_type": "markdown",
- "id": "0e3e4d85",
- "metadata": {},
+ "id": "69a93d3d-3a88-4297-b330-23a7fff2c4b4",
+ "metadata": {
+ "slideshow": {
+ "slide_type": "slide"
+ }
+ },
"source": [
"And this is also automatic when we instantiate a larger llama2-like model:"
]
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 25,
"id": "7fff2522",
"metadata": {},
"outputs": [
@@ -1716,7 +3905,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "deviation: 4.76837158203125e-07\n"
+ "deviation: 5.960464477539062e-07\n"
]
}
],
@@ -1740,34 +3929,37 @@
{
"cell_type": "markdown",
"id": "b538cb40",
- "metadata": {},
+ "metadata": {
+ "slideshow": {
+ "slide_type": "slide"
+ }
+ },
"source": [
- "By peeking into the trace, we can see that it actually used the unsloth RMS kernels:"
+ "By peeking into the trace, we can see that it actually used the unsloth apply rope:"
]
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 26,
"id": "c260cb25",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "[' (n_1, (_, _, _)) = unsloth_rms_norm_forward(x, t_transformer_h_0_norm_1_weight, 1e-05)',\n",
- " ' (t110, (_, _, _)) = unsloth_rms_norm_forward(t102, t_transformer_h_0_norm_2_weight, 1e-05)',\n",
- " ' (t139, (_, _, _)) = unsloth_rms_norm_forward(t130, t_transformer_h_1_norm_1_weight, 1e-05)',\n",
- " ' (t215, (_, _, _)) = unsloth_rms_norm_forward(t207, t_transformer_h_1_norm_2_weight, 1e-05)',\n",
- " ' (t243, (_, _, _)) = unsloth_rms_norm_forward(t235, t_transformer_ln_f_weight, 1e-05)']"
+ "[' (q_roped, (_, _)) = unsloth_apply_rope_forward(t55, cos, sin)',\n",
+ " ' (k_roped, (_, _)) = unsloth_apply_rope_forward(t57, cos, sin)',\n",
+ " ' (t165, (_, _)) = unsloth_apply_rope_forward(t164, cos, sin)',\n",
+ " ' (t167, (_, _)) = unsloth_apply_rope_forward(t166, cos, sin)']"
]
},
- "execution_count": 25,
+ "execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "[s for s in str(thunder.last_traces(thunder_model)[-1]).split('\\n') if 'rms' in s]"
+ "[s for s in str(thunder.last_traces(thunder_model)[-1]).split('\\n') if 'apply_rope' in s]"
]
},
{
@@ -1775,79 +3967,97 @@
"id": "0f6c0780",
"metadata": {},
"source": [
- "But what about the backward?\n",
+ "### But what about the backward?\n",
"\n",
- "Well, we have to connect forward and backward with a grad transformation. With our specialized ops, this is very simple, we compute the forward, call `get_grad` for the output, compute the backward, and put it on the input with `put_grads`."
+ "Well, we have to connect forward and backward with a grad transformation. With our specialized ops, this is very simple, we compute the forward, call `get_grad` for the output, compute the backward, and put it on the input with `put_grads`. \n"
]
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 27,
"id": "7670a872",
"metadata": {},
"outputs": [],
"source": [
"from thunder.core.transforms import get_grad, put_grads\n",
"\n",
- "def unsloth_rms_norm_grad(x: TensorProxy, weight, dim: int, eps: float, add_unit_offset: bool):\n",
- " res, (r, BLOCK_SIZE, num_warps) = unsloth_rms_norm_forward(x, weight, eps)\n",
+ "def unsloth_apply_rope_grad(x: TensorProxy, cos: TensorProxy, sin: TensorProxy):\n",
+ " res, (BLOCK_SIZE, num_warps) = unsloth_apply_rope_forward(x, cos, sin)\n",
" grad_res = get_grad(res)\n",
- " grad_x = unsloth_rms_norm_backward(x, weight, r, eps, BLOCK_SIZE, num_warps, grad_res)\n",
+ " grad_x = unsloth_apply_rope_backward(BLOCK_SIZE, num_warps, cos, sin, grad_res)\n",
" put_grads((x,), (grad_x,))\n",
" return res\n",
"\n",
- "my_ex.register_implementation(rms_norm, checker=rms_norm_to_unsloth_checker,\n",
- " execution_transform=rms_norm_to_unsloth,\n",
- " grad_transform=unsloth_rms_norm_grad \n",
+ "my_ex.register_implementation(apply_rope, checker=apply_rope_to_unsloth_checker,\n",
+ " execution_transform=apply_rope_to_unsloth,\n",
+ " grad_transform=unsloth_apply_rope_grad \n",
" )\n",
"\n"
]
},
{
- "cell_type": "code",
- "execution_count": 27,
- "id": "d31aced0",
+ "cell_type": "markdown",
+ "id": "219dfaa4-cdef-47de-b60c-7c7c1642cb84",
"metadata": {},
+ "source": [
+ "Note that the parts are not actually executed at the same time in the actual computation, but just during tracing.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "68226a4a-6ad8-43fb-b92f-c1e8eec6f13e",
+ "metadata": {},
+ "source": [
+ "And let us try our function using the optimized backward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "ccc3ed63-ddc2-4b0e-bcd0-f77d66fefe9f",
+ "metadata": {
+ "scrolled": true
+ },
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "torch.Size([256, 4096]) torch.Size([256, 4096]) torch.Size([4096]) torch.Size([256]) torch.Size([256, 4096])\n",
- "(4096, 1) (4096, 1) (1,) (1,) (4096, 1)\n",
- "maximum deviation grads: 3.5762786865234375e-07\n"
+ "res deviation: 0.015625\n",
+ "grad deviation: 0.0078125\n"
]
}
],
"source": [
- "with torch.device('cuda'):\n",
- " norm_module = ThunderizedRMSNorm(4096)\n",
- " norm_module.weight.requires_grad_(False)\n",
- " x = torch.randn(256, 4096, requires_grad=True)\n",
+ "Q.requires_grad_()\n",
"\n",
- "thunder_norm_module = thunder.jit(norm_module, executors=(my_ex,) + thunder.get_default_executors()) \n",
+ "thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors())\n",
"\n",
- "actual = thunder_norm_module(x)\n",
- "expected = norm_module(x)\n",
- "actual_grads = torch.autograd.grad(actual.sum(), x)\n",
- "expected_grads = torch.autograd.grad(expected.sum(), x)\n",
+ "expected = test_apply_rope(Q, m)\n",
+ "go = torch.ones_like(expected)\n",
+ "gr_expected, = torch.autograd.grad(expected, Q, go)\n",
+ "actual = thunder_apply_rope(Q, m)\n",
+ "gr_actual, = torch.autograd.grad(actual, Q, go)\n",
"\n",
- "print(\"maximum deviation grads:\", max((a-e).abs().max().item() for a, e in zip(actual_grads, expected_grads)))"
+ "print(\"res deviation:\", (expected - actual).abs().max().item())\n",
+ "print(\"grad deviation:\", (gr_expected - gr_actual).abs().max().item())"
]
},
{
"cell_type": "markdown",
- "id": "be218e9d",
+ "id": "63cb61ee-c791-49d1-ba5c-3fe4b5b9a9d5",
"metadata": {},
"source": [
- "And here is our module having the unsloth backward:"
+ "And with `last_backward_traces` we can check that our module is using the unsloth backward:"
]
},
{
"cell_type": "code",
"execution_count": 29,
- "id": "ac00153b",
- "metadata": {},
+ "id": "cd12ca02-6f06-4d88-b5b7-25c4c27dbc9a",
+ "metadata": {
+ "scrolled": true
+ },
"outputs": [
{
"data": {
@@ -1862,7 +4072,7 @@
" # saved_for_backward: \"Collection\" \n",
" # cotangents: \"Collection\" \n",
" C0, \\\n",
- " C1, \\\n",
+ " _, \\\n",
" = saved_for_backward\n",
" clear_collection(saved_for_backward)\n",
" del saved_for_backward\n",
@@ -1870,19 +4080,14 @@
" = cotangents\n",
" clear_collection(cotangents)\n",
" del cotangents\n",
- " t0, \\\n",
" t1, \\\n",
- " t3, \\\n",
+ " t2, \\\n",
" = C0\n",
" clear_collection(C0)\n",
" del C0\n",
- " f0, \\\n",
- " = C1\n",
- " clear_collection(C1)\n",
- " del C1\n",
- " t2 = unsloth_rms_norm_backward(t0, t1, t3, f0, 4096, 8, t4) # t2: \"cuda:0 f32[256, 4096]\"\n",
- " del t0, t1, t3, f0, t4\n",
- " return (t2, None)"
+ " t3 = unsloth_apply_rope_backward(8, 4, t1, t2, t4) # t3: \"cuda:0 bf16[2, 128, 4096, 16]\"\n",
+ " del t1, t2, t4\n",
+ " return (t3, None, None)"
]
},
"execution_count": 29,
@@ -1891,29 +4096,96 @@
}
],
"source": [
- "thunder.last_backward_traces(thunder_norm_module)[-1]"
+ "thunder.last_backward_traces(thunder_apply_rope)[-1]"
]
},
{
"cell_type": "markdown",
- "id": "26ac79f0",
+ "id": "2776d183-0232-495e-aa75-3b90e799c841",
"metadata": {},
"source": [
- "That's it! Do check out our LitGPT studios and the other tutorial notebooks.\n"
+ "### Comparing and exploring optimizations\n",
+ "\n",
+ "It is also straightforward to compare potential optimizations.\n",
+ "\n",
+ "Note again, that our use of the unsloth kernel might not result in the same performance as the unsloth project sees due to differences in the hardware used, software environment, or memory layout of the operands."
]
},
{
"cell_type": "code",
- "execution_count": null,
- "id": "586cdd30",
+ "execution_count": 30,
+ "id": "a5e0ce05",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "eager\n",
+ "3.84 ms ± 3.46 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
+ "thunder + unsloth\n",
+ "6.69 ms ± 3.45 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
+ "thunder default (nvfuser)\n",
+ "1.4 ms ± 4.98 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
+ ]
+ }
+ ],
+ "source": [
+ "def test_apply_rope_copy(x, m):\n",
+ " return apply_rope_copy(x, m.cos, m.sin)\n",
+ "\n",
+ "test_apply_rope_myex = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors()) \n",
+ "test_apply_rope_nvfuser = thunder.jit(test_apply_rope_copy)\n",
+ "y = test_apply_rope(Q, m); gr, = torch.autograd.grad(y, Q, go)\n",
+ "y = test_apply_rope_myex(Q, m); gr, = torch.autograd.grad(y, Q, go)\n",
+ "y = test_apply_rope_nvfuser(Q, m); gr, = torch.autograd.grad(y, Q, go)\n",
+ "\n",
+ "print(\"eager\")\n",
+ "%timeit y = test_apply_rope(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()\n",
+ "print(\"thunder + unsloth\")\n",
+ "%timeit y = test_apply_rope_myex(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()\n",
+ "print(\"thunder default (nvfuser)\")\n",
+ "%timeit y = test_apply_rope_nvfuser(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "08b8454f-c725-470c-92a5-56b2206af0e8",
"metadata": {},
- "outputs": [],
- "source": []
+ "source": [
+ "That's it!\n",
+ "\n",
+ "## Conclusion\n",
+ "\n",
+ "To wrap up, we hope you got a taste of\n",
+ "\n",
+ "- Getting things going with Thunder:\n",
+ "\n",
+ " - Applying Thunder through `thunder.jit` and\n",
+ " - using FSDP by just wrapping the model in `thunder.distributed.fsdp` before compilation.\n",
+ "\n",
+ "- See what's going on inspecting traces:\n",
+ "\n",
+ " - `thunder.last_traces` for the forward traces,\n",
+ " - `thunder.last_backward_traces` for the backward,\n",
+ " \n",
+ "- Extending Thunder:\n",
+ "\n",
+ " - registering operators with the `OperatorExecutor`,\n",
+ " - defining implementations with custom forward and backward to include optimized kernels.\n",
+ "\n",
+ "Keep in mind that Thunder is still experimental and only expected to work with the limited set of models we have tested it with. You will find bugs and missing pieces. Naturally, we would love for you to help us fix these! You can find us on the [Thunder section of the Lightning forums](https://lightning.ai/forums/c/thunder) or in the `#thunder` channel on the [PyTorch-Lightning slack](https://pytorch-lightning.slack.com/). \n",
+ "\n",
+ "Do check out our LitGPT studios and the other tutorial notebooks.\n"
+ ]
}
],
"metadata": {
+ "celltoolbar": "Slideshow",
"kernelspec": {
- "display_name": "Python 3 (ipykernel)",
+ "display_name": "Python 3",
"language": "python",
"name": "python3"
},
@@ -1927,7 +4199,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.7"
+ "version": "3.10.10"
}
},
"nbformat": 4,
diff --git a/requirements/docs.txt b/requirements/docs.txt
index f12c617045..14615a08c4 100644
--- a/requirements/docs.txt
+++ b/requirements/docs.txt
@@ -1,7 +1,7 @@
sphinx ==5.3.0
myst-parser ==1.0.0
nbsphinx ==0.9.3
-ipython[all] ==8.22.1
+ipython[all] ==8.22.2
pandoc ==2.3
docutils >=0.16
sphinxcontrib-fulltoc ==1.2.0
diff --git a/requirements/notebooks.txt b/requirements/notebooks.txt
index d2f9d92e56..47a14902ca 100644
--- a/requirements/notebooks.txt
+++ b/requirements/notebooks.txt
@@ -1 +1 @@
-ipython[all] ==8.22.1
+ipython[all] ==8.22.2
diff --git a/requirements/test.txt b/requirements/test.txt
index c6d95336ef..a1402bd69b 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -8,7 +8,7 @@ pytest-timestamper ==0.0.9
graphviz ==0.20.1
fdm ==0.4.1
expecttest ==0.2.1 # for test_ddp.py
-hypothesis ==6.98.15 # for test_ddp.py
+hypothesis ==6.99.10 # for test_ddp.py
numpy # for test_ops.py
einops # for test_einops.py
lit_gpt @ git+https://github.com/Lightning-AI/lit-gpt@f241d94df59d82b2017bfdcd3800ac8779eb45f5
diff --git a/setup.py b/setup.py
index 566f2bbc40..b0ee14e897 100755
--- a/setup.py
+++ b/setup.py
@@ -21,7 +21,7 @@ def _load_py_module(fname, pkg="thunder"):
def _load_requirements(path_dir: str, file_name: str = "requirements.txt") -> list:
reqs = parse_requirements(open(os.path.join(path_dir, file_name)).readlines())
- return list(map(str, reqs))
+ return [r for r in list(map(str, reqs)) if "@" not in r]
def _prepare_extras(
@@ -43,6 +43,19 @@ def _prepare_extras(
return extras
+def _load_readme_description(path_dir: str, homepage: str, version: str) -> str:
+ """Load readme as decribtion."""
+ path_readme = os.path.join(path_dir, "README.md")
+ with open(path_readme, encoding="utf-8") as fp:
+ text = fp.read()
+ # https://github.com/Lightning-AI/lightning-thunder/raw/master/docs/source/_static/images/lightning_module/pt_to_pl.png
+ github_source_url = os.path.join(homepage, "raw", version)
+ # replace relative repository path to absolute link to the release
+ # do not replace all "docs" as in the readme we replace some other sources with particular path to docs
+ text = text.replace("docs/source/_static/", f"{os.path.join(github_source_url, 'docs/source/_static/')}")
+ return text
+
+
about = _load_py_module("__about__.py")
# https://packaging.python.org/discussions/install-requires-vs-requirements /
@@ -58,7 +71,9 @@ def _prepare_extras(
download_url="https://github.com/Lightning-AI/lightning-thunder",
license=about.__license__,
packages=find_packages(exclude=["thunder/tests", "docs"]),
- long_description=about.__long_doc__,
+ long_description=_load_readme_description(
+ path_dir=_PATH_ROOT, homepage=about.__homepage__, version=about.__version__
+ ),
long_description_content_type="text/markdown",
include_package_data=True,
zip_safe=False,
diff --git a/thunder/__about__.py b/thunder/__about__.py
index d9fb64ba1d..15e838ef4f 100644
--- a/thunder/__about__.py
+++ b/thunder/__about__.py
@@ -1,21 +1,17 @@
-__version__ = "0.0.0dev"
+__version__ = "0.1.0"
__author__ = "Lightning-AI et al"
__author_email__ = "community@lightning.ai"
__license__ = "Apache 2.0"
__copyright__ = f"2024 {__author__}"
__homepage__ = "https://github.com/Lightning-AI/lightning-thunder"
__docs__ = "Lightning Thunder project."
-# todo: consider loading Readme here...
-__long_doc__ = """
-Lightning Thunder is a deep learning compiler for PyTorch.
-"""
+
__all__ = [
"__author__",
"__author_email__",
"__copyright__",
"__docs__",
- "__long_doc__",
"__homepage__",
"__license__",
"__version__",
diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index 40db2db5de..9120584989 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -9,13 +9,9 @@
import thunder
from thunder.tests.lit_gpt_model import Config, GPT, Block
-try:
- from lightning.fabric.utilities.throughput import measure_flops
+from lightning.fabric.utilities.throughput import measure_flops
+from lightning.fabric.utilities import Throughput
- # from lightning.fabric.utilities import Throughput
- LIGHTNING_AVAILABLE = True
-except:
- LIGHTNING_AVAILABLE = False
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
@@ -108,16 +104,21 @@ def __init__(
if n_layers is not None:
self.config.n_layer = n_layers
- # Initialize the model and the optimizer
+ # Initialize the model
+ t0 = time.perf_counter()
+ print(f"Loading model with {self.config.__dict__}")
self.model = self.init_model()
- self.optimizer = configure_optimizers(
- self.model, weight_decay, learning_rate, (beta1, beta2), device_type="cuda"
- )
+ print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
# Setup the distributed algorithm choices
if self.distributed_mode != "none":
self.model = self.setup_distributed()
+ # Initialize the optimizer after the model is sharded if using FSDP
+ self.optimizer = configure_optimizers(
+ self.model, weight_decay, learning_rate, (beta1, beta2), device_type="cuda"
+ )
+
# Compile the model
if self.compile not in ["eager", None]:
self.model = self.setup_compile()
@@ -136,13 +137,10 @@ def __init__(
}
def init_model(self):
- print(f"Loading model with {self.config.__dict__}")
- t0 = time.perf_counter()
- with self.device:
+ init_device = torch.device("meta") if self.distributed_mode == "fsdp" else self.device
+ with init_device:
model = GPT(self.config)
- model.to(dtype=torch.bfloat16)
- print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
-
+ model.to(dtype=torch.bfloat16)
return model
def setup_distributed(self):
@@ -240,7 +238,7 @@ def pad_collate(batch):
y_padded = torch.nn.utils.rnn.pad_sequence(y, batch_first=True, padding_value=-1)
return x_padded, y_padded
- train_data = DummyDataset(self.model.max_seq_length, self.dynamic)
+ train_data = DummyDataset(self.config.block_size, self.dynamic)
train_dataloader = DataLoader(
train_data, batch_size=self.micro_batch_size, num_workers=2, collate_fn=pad_collate
)
@@ -248,24 +246,30 @@ def pad_collate(batch):
return train_dataloader
def calculate_model_flops(self):
- input_ids, targets = next(self.train_data_iter)
- input_ids = input_ids.to(device=self.device)
- targets = targets.to(device=self.device)
+ meta = torch.device("meta")
+ device = self.device
+ self.device = meta
+
+ # calculate flops on a meta-device model because we only care about the shapes and
+ # because the flops calculator installs hooks on the model
+ meta_model = self.init_model()
- model_fwd = lambda: self.model(input_ids)
+ x = torch.randint(0, 1, (self.micro_batch_size, meta_model.config.block_size), device=meta)
+ model_fwd = lambda: meta_model(x)
model_loss = lambda y: torch.nn.functional.cross_entropy(
- y.reshape(-1, y.size(-1)), targets.reshape(-1), ignore_index=-1
+ y.reshape(-1, y.size(-1)), x.reshape(-1), ignore_index=-1
)
- if LIGHTNING_AVAILABLE:
- self.perf_metrics["model_flops"] = measure_flops(self.model, model_fwd, model_loss) / 1e12
+ self.perf_metrics["model_flops"] = measure_flops(meta_model, model_fwd, model_loss)
+
+ self.device = device
def train(self):
t0 = None
- # if global_rank in [0, None]:
- # #Calculate the model FLOPs
- # self.calculate_model_flops()
- # Setup Perf Collection
- # self.throughput = Throughput(window_size=10, world_size=world_size)
+ if global_rank in [0, None]:
+ # Calculate the model FLOPs
+ self.calculate_model_flops()
+ # Setup throughput Collection
+ self.throughput = Throughput(window_size=self.max_iters - self.warmup_iter, world_size=world_size)
if "transformerengine" in self.compile:
import transformer_engine.pytorch as te
@@ -323,45 +327,30 @@ def train(self):
print(
f"iter {i}: loss {loss_item:.4f}, iter time: {(t1 - iter_t0) * 1000:.2f}ms, t: {input_ids.size(1)}"
)
-
- # if global_rank in [0, None] and i >=warmup_iter:
- # self.throughput.update(
- # time=(t1-t0),
- # flops=self.model_flops,
- # batches=i,
- # samples=(i * self.micro_batch_size * self.gradient_accumulation_steps),
- # lengths=(i * self.micro_batch_size * self.gradient_accumulation_steps * self.model.max_seq_length),
- # )
-
- # metrics = self.throughput.compute()
- # if i % 10 == 0:
- # print(metrics)
+ if i >= self.warmup_iter:
+ self.throughput.update(
+ time=(t1 - t0),
+ flops=self.perf_metrics["model_flops"],
+ batches=i,
+ samples=(i * self.micro_batch_size * self.gradient_accumulation_steps),
+ lengths=(i * self.micro_batch_size * self.gradient_accumulation_steps * self.config.block_size),
+ )
if global_rank in [0, None]:
# print(f"Total time: {(t1 - t0):.2f}s")
- # print(f"Average time per iter: {((t1 - t0)*1000)/(max_iters-warmup_iter):.2f}ms")
self.perf_metrics["average_iter_time"] = ((t1 - t0) * 1000) / (self.max_iters - self.warmup_iter)
def add_perf_metrics(self):
- # tokens_per_sec = total number of benchmarked iterations x global BS x block_size / total elapsed time (s)
- # = global BS x block_size / (total elapsed time (s)/total number of benchmarked iterations)
- # = global BS x block_size / average iter time (s)
- self.perf_metrics["tokens_per_sec"] = (
- self.global_batch_size * self.model.max_seq_length * 1000 / self.perf_metrics["average_iter_time"]
- ) # tokens/s
- if self.perf_metrics["model_flops"] is not None:
- self.perf_metrics["model_flop_per_sec"] = (
- self.perf_metrics["model_flops"] * 1000 / self.perf_metrics["average_iter_time"]
- )
- if world_size is not None:
- self.perf_metrics["model_flop_per_sec"] *= world_size
+ metrics = self.throughput.compute()
+ self.perf_metrics["tokens_per_sec"] = metrics.get("items_per_sec", metrics["device/items_per_sec"])
+ self.perf_metrics["model_flop_per_sec"] = metrics.get("flops_per_sec", metrics["device/flops_per_sec"])
self.perf_metrics["memory_used_GB"] = torch.cuda.max_memory_allocated() / 1e9
def add_model_info_to_metrics(self):
if global_rank in [0, None]:
self.perf_metrics["model_name"] = self.model_name
self.perf_metrics["Num GPUS"] = world_size
- self.perf_metrics["Seq Len"] = self.model.max_seq_length
+ self.perf_metrics["Seq Len"] = self.config.block_size
self.perf_metrics["Micro BS"] = self.micro_batch_size
self.perf_metrics["Global BS"] = self.global_batch_size
self.perf_metrics["GA"] = self.gradient_accumulation_steps
@@ -413,7 +402,7 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
benchmark.add_perf_metrics()
print(
- f"Model name: {benchmark.model_name}\nSeq Length: {benchmark.model.max_seq_length}\nMicro BS: {benchmark.micro_batch_size}\nGlobal BS: {benchmark.global_batch_size}"
+ f"Model name: {benchmark.model_name}\nSeq Length: {benchmark.config.block_size}\nMicro BS: {benchmark.micro_batch_size}\nGlobal BS: {benchmark.global_batch_size}"
)
print(
f"Number of Layers: {benchmark.config.n_layer}\nNumber of parameters: {sum(p.numel() for p in benchmark.model.parameters() if p.requires_grad) / 1e9:.02f}B"
@@ -426,12 +415,9 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
print(f"Compiler: {benchmark.compile}")
print(f"Average iter time: {benchmark.perf_metrics['average_iter_time']:.2f} ms")
print(f"Memory used: {benchmark.perf_metrics['memory_used_GB']:.02f} GB")
- print(f"Throughput (Tokens/s): {benchmark.perf_metrics['tokens_per_sec']:.02f} tokens/s")
- print(
- f"Normalized Throughput (Tokens/s/GPU): {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f} tokens/s/gpu"
- )
- if benchmark.perf_metrics["model_flop_per_sec"] is not None:
- print(f"Model TFLOP/s: {benchmark.perf_metrics['model_flop_per_sec']:.02f} TFLOP/s")
+ print(f"Tokens/s: {benchmark.perf_metrics['tokens_per_sec']:.02f}")
+ print(f"Tokens/s/GPU: {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f}")
+ print(f"TFLOP/s: {benchmark.perf_metrics['model_flop_per_sec'] / 1e12:.02f}")
except Exception as error:
# Helps catch OutOfMemory Errors and post processing of errors
diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 8d18d40904..772e65a84d 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -1237,11 +1237,6 @@ def _embedding_prim_grad(
def _get_gradfn(bsym: BoundSymbol, *, executors_list: Sequence[Any] = tuple()) -> None | Callable:
- # If executor specific `aug_fwd_rule` exists then we will use that,
- # so we return `None` here.
- if get_executor_specific_aug_fwd_rule(bsym):
- return None
-
cd = get_compile_data()
executors_list = cd.executors_list if cd is not None else executors_list
# Checks if the executor which has priority for this operation has a specific grad transform for it
@@ -2484,15 +2479,6 @@ def zeros_like(x):
}
-@dataclass(**default_dataclass_params)
-class RuleInfo:
- checker: Callable
- rule: Callable
- fw_fallback: Callable
- bw_fallback: Callable
- executor: Executor
-
-
def register_augmented_forward(op):
"""Decorator to register an augmented forward implementation for a symbol.
@@ -2510,40 +2496,6 @@ def decorator(func):
return decorator
-def register_augmented_forward_with_checker(executor, op, checker, rule):
- """Decorator to register an augmented forward implementation for a symbol.
-
- Args:
- executor (Executor): Executor to which the rule applies.
- op (Ops): Symbol for which to register the augmented forward implementation.
- checker (Callable): Function that checks if the rule should be applied.
- rule (Callable): Function that applies the rule.
- """
- fw_fallback = augmented_forward_impls.get(op, None)
- bw_fallback = backward_impls.get(op, None)
- augmented_forward_impls[executor, op] = RuleInfo(checker, rule, fw_fallback, bw_fallback, executor)
-
-
-def deregister_augmented_forward_and_backward(op):
- """Deregisters an augmented forward implementation and a backward
- implementation for a symbol.
-
- Args:
- op (Ops): Symbol for which to deregister the augmented forward
- implementation and the backward implementation.
-
- Returns:
- None
- """
- # Restore the fallback implementation if it exists
- if isinstance(augmented_forward_impls[op], RuleInfo):
- backward_impls[op] = augmented_forward_impls[op].bw_fallback
- augmented_forward_impls[op] = augmented_forward_impls[op].fw_fallback
- else:
- del augmented_forward_impls[op]
- del backward_impls[op]
-
-
def register_backward(op):
"""Decorator to register a backward implementation for a symbol.
@@ -3320,31 +3272,6 @@ def uniform_backward(primal, minval, maxval, g):
nondifferentiable_vjp_symbols = (prims.PrimIDs.BITWISE_AND, prims.PrimIDs.SIGNBIT, prims.PrimIDs.FULL)
-def get_executor_specific_aug_fwd_rule(symbol: BoundSymbol) -> RuleInfo | None:
- """Get executor specific augmented forward rule.
-
- Args:
- symbol (BoundSymbol): BoundSymbol to get the rule for.
-
- Returns:
- RuleInfo: Rule info for the symbol.
- """
- cd = get_compile_data()
- if cd is None:
- return None
-
- # Search for the executor specific rules. When there are multiple rules
- # for the same symbol, we use the left-most executor in the list (i.e.
- # the one with the highest priority) and we fallback to the next one if
- # the checker returns False.
- for executor in cd.executors_list:
- candidate = augmented_forward_impls.get((executor, symbol.sym.id))
- if isinstance(candidate, RuleInfo) and candidate.checker(*symbol.args, **symbol.kwargs):
- return candidate
-
- return None
-
-
def is_constant_for_vjp(symbol: prims.Symbol) -> bool:
"""Check if a symbol is constant for the VJP transform.
@@ -3387,19 +3314,10 @@ def vjp_impl_const(symbol, *args, **kwargs):
# Normal case, we have a proxy tangent
vjp_impl = augmented_forward_impls.get(symbol.sym.id)
- vjp_impl = get_executor_specific_aug_fwd_rule(symbol) or vjp_impl
if _get_gradfn(symbol) is not None:
vjp_impl, backward_fn = make_aug_forward_and_backward(symbol)
- if isinstance(vjp_impl, RuleInfo):
- # We should use this rule only if checker returns True for the current
- # symbol's arguments
- if vjp_impl.checker(*symbol.args, **symbol.kwargs):
- vjp_impl = vjp_impl.rule
- else:
- vjp_impl = vjp_impl.fw_fallback
-
if vjp_impl is None:
# We could not find a VJP for this symbol, so we try to decompose it
if len(symbol.subsymbols) > 0 and not isinstance(symbol.sym.id, prims.PrimIDs):
@@ -3567,14 +3485,10 @@ def put_grad(v: Variable, val: Any) -> None:
backward = backward_impls.get(symbol.sym.id)
aug_forward = augmented_forward_impls.get(symbol.sym.id)
- aug_forward = get_executor_specific_aug_fwd_rule(symbol) or aug_forward
if _get_gradfn(symbol) is not None:
aug_forward, backward = make_aug_forward_and_backward(symbol)
- if isinstance(aug_forward, RuleInfo):
- backward = backward_impls[aug_forward.executor, symbol.sym.id]
-
if backward is None:
if len(symbol.subsymbols) > 0 and not isinstance(symbol.sym.id, prims.PrimIDs):
# We could not find a backward for this symbol, so we try to decompose it
@@ -3984,10 +3898,13 @@ def decorator(func):
def maybe_downcast_to(dtype, args):
allowed_downcast_types = (dtypes.float16, dtypes.bfloat16, dtypes.float32)
- if all(tree_map(lambda a: a.dtype in allowed_downcast_types, args)):
- return tree_map(lambda a: maybe_convert_to_dtype(a, dtype), args)
- else:
- return args
+
+ def map_fn(a):
+ if isinstance(a, TensorProxy) and a.dtype in allowed_downcast_types:
+ return maybe_convert_to_dtype(a, dtype)
+ return a
+
+ return tree_map(map_fn, args)
@register_autocast_rule("torch.matmul")
diff --git a/thunder/distributed/__init__.py b/thunder/distributed/__init__.py
index 39ae65bda5..0a0eb943d9 100644
--- a/thunder/distributed/__init__.py
+++ b/thunder/distributed/__init__.py
@@ -68,9 +68,11 @@ def skip_data_parallel_grad_sync() -> None:
def _sync_grads(module: torch.nn.Module) -> None:
+ import thunder
+
params_with_grad = [p for p in module.parameters() if p.grad is not None]
grads = [p.grad for p in params_with_grad]
- process_group = module._lc_cd.process_group_for_ddp
+ process_group = thunder.compile_data(module).process_group_for_ddp
torch._foreach_div_(grads, process_group.size())
with tdist.distributed_c10d._coalescing_manager(group=process_group, async_ops=True) as cm:
for g in grads:
diff --git a/thunder/executors/apex_entropyex.py b/thunder/executors/apex_entropyex.py
index 8a82e04e20..818199ad5b 100644
--- a/thunder/executors/apex_entropyex.py
+++ b/thunder/executors/apex_entropyex.py
@@ -11,10 +11,6 @@
from thunder.core.symbol import Symbol
from thunder.core.utils import check, same_shape
from thunder.core.transforms import get_grad, put_grad, put_grads, mean_backward, restore_reduced_dims
-from thunder.core.transforms import (
- register_augmented_forward_with_checker,
- register_backward,
-)
from thunder.extend import OperatorExecutor, register_executor
@@ -197,76 +193,6 @@ def _cross_entropy_checker(
return True
-# Check out the 'add vjp rule' dev tutorial on how to add a VJP rule for any
-# Symbol. We use our new primitives to register a VJP rule for
-# torch.nn.functional.cross_entropy. This function is registered as the
-# augmented forward rule for torch.nn.functional.cross_entropy below
-def apex_cross_entropy_forward_rule(
- a,
- target,
- weight=None,
- size_average=None,
- ignore_index=-100,
- reduce=None,
- reduction="mean",
- label_smoothing=0.0,
-):
- loss, max_log_sum_exp = apex_xentropy(
- a,
- target=target,
- reduction=reduction,
- label_smoothing=label_smoothing,
- )
- primal = loss
- saved_for_backward = (a, target, max_log_sum_exp, reduction, label_smoothing)
- return primal, saved_for_backward
-
-
-register_augmented_forward_with_checker(
- apex_ex,
- ltorch.cross_entropy.id,
- _cross_entropy_checker,
- apex_cross_entropy_forward_rule,
-)
-
-
-# This function is the backward rule for torch.nn.functional.cross_entropy. It
-# accepts the primal output and saved_for_backward from the forward pass and
-# returns the backward output. The backward output is a tuple of the backward
-# output for each differentiable Tensor input to the forward pass. In this case,
-# the forward pass has 1 such input, so the backward output is a single Tensor.
-# This function is registered as the backward rule for
-# torch.nn.functional.cross_entropy
-@register_backward((apex_ex, ltorch.cross_entropy.id))
-def apex_cross_entropy_backward_rule(
- logits,
- labels,
- max_log_sum_exp,
- reduction,
- smoothing,
- grad,
-):
- from thunder.core.transforms import mean_backward, sum_backward
-
- if reduction == "mean":
- grad = mean_backward(max_log_sum_exp.ndim, max_log_sum_exp.shape, (0,), grad)
- elif reduction == "sum":
- grad = sum_backward(max_log_sum_exp.shape, (0,), grad)
- elif reduction == "none":
- pass
- else:
- raise ValueError(f"Invalid reduction: {reduction}")
-
- grad_logits = apex_xentropy_bwd(
- grad,
- logits,
- target=labels,
- max_log_sum_exp=max_log_sum_exp,
- label_smoothing=smoothing,
- )
- return grad_logits
-
-
# Translate calls from torch.nn.functional.cross_entropy to apex_xentropy (when the checker above returns True)
def _cross_entropy_transform(
a: TensorProxy,
diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py
index 9fb6c50a48..75494cff5a 100644
--- a/thunder/executors/cudnnex.py
+++ b/thunder/executors/cudnnex.py
@@ -35,8 +35,6 @@ def cudnn_available() -> bool:
get_grad,
put_grad,
put_grads,
- register_augmented_forward_with_checker,
- register_backward,
)
from thunder.extend import OperatorExecutor, register_executor
import thunder.torch as ltorch
@@ -338,7 +336,11 @@ def _cudnn_sdpa_forward_checker(
if d % 8 != 0 or d > 128:
return False
- return True
+ is_backward_supported = _cudnn_sdpa_backward_checker(
+ query, key, value, attn_mask, dropout_p, is_causal, scale=scale
+ )
+
+ return True and is_backward_supported
@langctx("torch")
@@ -601,99 +603,6 @@ def cudnn_sdpa_bwd_impl(
)
-@langctx("torch")
-def cudnn_sdpa_aug_fw_rule_checker(
- query: TensorProxy,
- key: TensorProxy,
- value: TensorProxy,
- attn_mask: None | TensorProxy,
- dropout_p: float,
- is_causal: bool,
- *,
- scale: None | float,
-) -> bool:
- from thunder.core.compile_data import get_compile_data
-
- cd = get_compile_data()
- if cudnn_ex in cd.executors_list:
- is_forward_supported = _cudnn_sdpa_forward_checker(
- query, key, value, attn_mask, dropout_p, is_causal, scale=scale
- )
- is_backward_supported = _cudnn_sdpa_backward_checker(
- query, key, value, attn_mask, dropout_p, is_causal, scale=scale
- )
- return is_forward_supported and is_backward_supported
- return False
-
-
-def cudnn_sdpa_aug_fw_rule(
- query,
- key,
- value,
- attn_mask=None,
- dropout_p: float = 0.0,
- is_causal: bool = False,
- *,
- scale: float | None = None,
-):
- output, softmax_stats, seed, offset = cudnn_sdpa_fwd(
- query, key, value, attn_mask, dropout_p, is_causal, scale=scale
- )
- saved_for_backward = (
- query,
- key,
- value,
- attn_mask,
- dropout_p,
- is_causal,
- scale,
- output,
- softmax_stats,
- seed,
- offset,
- )
- return output, saved_for_backward
-
-
-register_augmented_forward_with_checker(
- cudnn_ex,
- "torch.nn.functional.scaled_dot_product_attention",
- cudnn_sdpa_aug_fw_rule_checker,
- cudnn_sdpa_aug_fw_rule,
-)
-
-
-@register_backward((cudnn_ex, "torch.nn.functional.scaled_dot_product_attention"))
-def cudnn_sdpa_backward_rule(
- query: Proxy,
- key: Proxy,
- value: Proxy,
- attn_mask: None | Proxy,
- dropout_p: float,
- is_causal: bool,
- scale: None | float,
- out: Proxy,
- softmax_stats: Proxy,
- seed: Proxy,
- offset: Proxy,
- grad_out: Proxy,
-):
- return cudnn_sdpa_bwd(
- grad_out,
- query,
- key,
- value,
- attn_mask,
- dropout_p,
- is_causal,
- out,
- softmax_stats,
- seed,
- offset,
- scale=scale,
- )
-
-
@langctx("torch")
def _cudnn_sdpa_transform(
query: TensorProxy,
@@ -726,7 +635,7 @@ def _cudnn_sdpa_grad(
)
g = get_grad(primal)
- grad_query, grad_key, grad_val, grad_attn_mask = cudnn_sdpa_bwd(
+ grads = cudnn_sdpa_bwd(
g,
query,
key,
@@ -740,6 +649,11 @@ def _cudnn_sdpa_grad(
offset,
scale=scale,
)
+ if attn_mask is None:
+ grad_query, grad_key, grad_val = grads
+ else:
+ grad_query, grad_key, grad_val, grad_attn_mask = grads
+
put_grads((query, key, value), (grad_query, grad_key, grad_val))
if attn_mask is not None:
put_grad(attn_mask, grad_attn_mask)
diff --git a/thunder/executors/sdpaex.py b/thunder/executors/sdpaex.py
index 005171e4e1..a2ff22007b 100644
--- a/thunder/executors/sdpaex.py
+++ b/thunder/executors/sdpaex.py
@@ -17,8 +17,6 @@
get_grad,
put_grad,
put_grads,
- register_augmented_forward_with_checker,
- register_backward,
)
from thunder.extend import OperatorExecutor, register_executor
diff --git a/thunder/executors/transformer_engineex.py b/thunder/executors/transformer_engineex.py
index 91d32e7a88..ced5d8fdb1 100644
--- a/thunder/executors/transformer_engineex.py
+++ b/thunder/executors/transformer_engineex.py
@@ -19,10 +19,6 @@
import thunder.core.prims as prims
from thunder.core.proxies import TensorProxy, CollectionProxy
from thunder.core.symbol import Symbol
-from thunder.core.transforms import (
- register_augmented_forward_with_checker,
- register_backward,
-)
from thunder.extend import OperatorExecutor, register_executor
__all__ = [
@@ -411,15 +407,6 @@ def linear_forward_rule_checker(a: TensorProxy, w: TensorProxy, bias: None | Ten
return False
-register_augmented_forward_with_checker(
- transformer_engine_ex,
- prims.linear.id,
- linear_forward_rule_checker,
- linear_forwad_rule,
-)
-
-
-@register_backward((transformer_engine_ex, prims.linear.id))
def linear_backward_rule(a_shape, w_shape, b_shape, ctx_idx, grad):
return te_functional_linear_backward(grad, a_shape, w_shape, b_shape, ctx_idx)
@@ -429,9 +416,21 @@ def _linear_transform(a: TensorProxy, w: TensorProxy, b: TensorProxy) -> torch.T
return _create_fp8_linear_bound_symbol(a, w, b, is_grad_enabled=False)
+def _linear_grad(a: TensorProxy, w: TensorProxy, b: TensorProxy) -> TensorProxy:
+ out, saved_for_backward = linear_forwad_rule(a, w, b)
+ g = prims.get_grad(out)
+ ga, gw, gb = linear_backward_rule(*saved_for_backward, g)
+ prims.put_grad(a, ga)
+ prims.put_grad(w, gw)
+ if b is not None:
+ prims.put_grad(b, gb)
+ return out
+
+
# Registers the implementation for torch.nn.functional.linear
transformer_engine_ex.register_implementation(
prims.linear,
checker=_linear_checker,
execution_transform=_linear_transform,
+ grad_transform=_linear_grad,
)
diff --git a/thunder/extend/__init__.py b/thunder/extend/__init__.py
index 2d2ce1fd2a..2acdcf6427 100644
--- a/thunder/extend/__init__.py
+++ b/thunder/extend/__init__.py
@@ -25,6 +25,7 @@
"add_always_executor",
"remove_default_executor",
"remove_always_executor",
+ "register_lookaside",
]
@@ -208,6 +209,7 @@ def register_operator(
module: None | type | ModuleType = None,
fn: None | Callable = None,
bind_postprocess: None | Callable = None,
+ replaces: None | Callable = None,
python_printer: Callable = default_python_printer,
) -> Symbol:
assert (like is None) ^ (meta is None), "Expected one and only one of 'like' and 'meta' to be specified"
@@ -237,6 +239,9 @@ def _bind_postprocess(bsym: BoundSymbol) -> None:
)
self.opmap[name] = sym
+ if replaces is not None:
+ register_lookaside(replaces, sym)
+
return sym
def register_implementation(
@@ -381,3 +386,10 @@ def deregister_executor(ex: Hashable | Executor) -> None:
remove_always_executor(id)
remove_default_executor(id)
+
+
+def register_lookaside(function, symbol) -> None:
+ """register `symbol` as a lookaside for `function`"""
+ import thunder.core.jit_ext
+
+ thunder.core.jit_ext._general_jit_lookaside_map[function] = thunder.core.jit_ext.interpreter_needs_wrap(symbol)
diff --git a/thunder/py.typed b/thunder/py.typed
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/thunder/tests/test_cudnn_executor.py b/thunder/tests/test_cudnn_executor.py
index 095c4e18b4..4cfb33316c 100644
--- a/thunder/tests/test_cudnn_executor.py
+++ b/thunder/tests/test_cudnn_executor.py
@@ -109,10 +109,8 @@ def test_cudnn_sdpa():
query = 1 * (torch.randn(shape_Q, dtype=thunder.torch.to_torch_dtype(dtype), device="cuda") - 0.5)
key = 2 * (torch.randn(shape_K, dtype=thunder.torch.to_torch_dtype(dtype), device="cuda") - 0.5)
value = 3 * (torch.randn(shape_V, dtype=thunder.torch.to_torch_dtype(dtype), device="cuda") - 0.5)
- is_causal = False
- attn_mask = torch.randn(
- s_q, s_kv, requires_grad=False, device="cuda", dtype=thunder.torch.to_torch_dtype(dtype)
- )
+ is_causal = True
+ attn_mask = None
expected = torch.nn.functional.scaled_dot_product_attention(
query, key, value, is_causal=is_causal, attn_mask=attn_mask
diff --git a/thunder/tests/test_extend.py b/thunder/tests/test_extend.py
index 03277f3419..b7dd300d12 100644
--- a/thunder/tests/test_extend.py
+++ b/thunder/tests/test_extend.py
@@ -137,14 +137,17 @@ def test_register_implementation_custom_op():
myex = OperatorExecutor("myex", version="0.1")
register_executor(myex)
+ def official_add(a, b):
+ return a + b
+
def _myadd(a, b):
return a + b
- myadd1 = myex.register_operator("myadd1", like=_myadd, fn=_myadd)
+ myadd1 = myex.register_operator("myadd1", like=_myadd, fn=_myadd, replaces=official_add)
myadd2 = myex.register_operator("myadd2", like=_myadd, fn=_myadd)
def fn(a, b):
- return myadd1(a, b)
+ return official_add(a, b)
cfn = thunder.jit(fn, executors=[myex])
diff --git a/thunder/tests/test_transformer_engine_executor.py b/thunder/tests/test_transformer_engine_executor.py
index 60aa6f64ea..41b6f83e36 100644
--- a/thunder/tests/test_transformer_engine_executor.py
+++ b/thunder/tests/test_transformer_engine_executor.py
@@ -70,7 +70,8 @@ def fn(x, w1, w2):
assert_close(w2.grad, te_linear2.weight.grad)
# Verifies te_linear was called
- forward_trace, backward_trace = thunder.last_traces(cfn)
+ forward_trace = thunder.last_traces(cfn)
+ backward_trace = thunder.last_backward_traces(cfn)
assert any(bsym.sym.name.startswith("te_linear") for bsym in forward_trace[-1].bound_symbols)
assert any(bsym.sym.name.startswith("te_functional_linear_backward") for bsym in backward_trace[-1].bound_symbols)
@@ -180,6 +181,6 @@ def foo(x, w):
)
cfunc(x, w)
- fwd_traces, _ = thunder.last_traces(cfunc)
+ fwd_traces = thunder.last_traces(cfunc)
# Verify that we have replaced `prims.linear` with `te_linear`
assert any(bsym.sym.name.startswith("te_linear") for bsym in fwd_traces[-1].bound_symbols)