diff --git a/.azure/docker-build.yml b/.azure/docker-build.yml index 436a3a3a8b..73233ae78c 100644 --- a/.azure/docker-build.yml +++ b/.azure/docker-build.yml @@ -40,21 +40,24 @@ jobs: #maxParallel: "3" matrix: # CUDA 12.1 - 'cuda 12.1 | torch 2.2': - {CUDA_VERSION: '12.1.1', TORCH_VERSION: '2.2.1', TRITON_VERSION: '2.2.0'} - 'cuda 12.1 | torch 2.3 /nightly': - {CUDA_VERSION: '12.1.1', TORCH_VERSION: 'main', TORCH_INSTALL: 'source'} + 'cuda 12.1 | torch 2.2 | cudnn FE v1.1': # todo: drop updating this image when CI transition to newer FE version + {CUDA_VERSION: '12.1.1', TORCH_VERSION: '2.2.1', TRITON_VERSION: '2.2.0', CUDNN_FRONTEND: "1.1.0"} + 'cuda 12.1 | torch 2.2 | cudnn FE v1.2': + {CUDA_VERSION: '12.1.1', TORCH_VERSION: '2.2.1', TRITON_VERSION: '2.2.0', CUDNN_FRONTEND: "1.2.0"} + 'cuda 12.1 | torch 2.3 /nightly | cudnn FE v1.1': # todo: drop updating this image when CI transition to newer FE version + {CUDA_VERSION: '12.1.1', TORCH_VERSION: 'main', TORCH_INSTALL: 'source', CUDNN_FRONTEND: "1.1.0"} + 'cuda 12.1 | torch 2.3 /nightly | cudnn FE v1.2': + {CUDA_VERSION: '12.1.1', TORCH_VERSION: 'main', TORCH_INSTALL: 'source', CUDNN_FRONTEND: "1.2.0"} #'cuda 12.1': # this version - '8.9.5.29-1+cuda12.1' for 'libcudnn8' was not found - # how long to run the job before automatically cancelling - timeoutInMinutes: "95" # how much time to give 'run always even if cancelled tasks' before stopping them cancelTimeoutInMinutes: "2" + timeoutInMinutes: "95" variables: UBUNTU_VERSION: '22.04' PYTHON_VERSION: '3.10' imageRepository: 'pytorchlightning/lightning-thunder' dockerfilePath: 'dockers/ubuntu-cuda/Dockerfile' - imageTag: 'ubuntu$(UBUNTU_VERSION)-cuda$(CUDA_VERSION)-py$(PYTHON_VERSION)-pt_${TORCH_VERSION/v/}' + imageTag: 'ubuntu$(UBUNTU_VERSION)-cuda$(CUDA_VERSION)-cudnn-fe$(CUDNN_FRONTEND)-py$(PYTHON_VERSION)-pt_${TORCH_VERSION/v/}' pool: 'lit-rtx-3090' workspace: clean: all @@ -74,11 +77,13 @@ jobs: -f $(dockerfilePath) \ --build-arg UBUNTU_VERSION="$(UBUNTU_VERSION)" \ --build-arg CUDA_VERSION="$(CUDA_VERSION)" \ + --build-arg CUDNN_FRONTEND_CHECKOUT="v$(CUDNN_FRONTEND)" \ --build-arg PYTHON_VERSION="$(PYTHON_VERSION)" \ --build-arg TORCH_VERSION="$(TORCH_VERSION)" \ --build-arg TRITON_VERSION="$(TRITON_VERSION)" \ --build-arg TORCH_INSTALL="$(TORCH_INSTALL)" \ . --no-cache + timeoutInMinutes: "95" displayName: 'Build base image' - bash: | @@ -98,6 +103,7 @@ jobs: echo $(DOCKERHUB_PAT) | docker login --username $(DOCKERHUB_USER) --password-stdin docker push $(imageRepository):$(imageTag) condition: ne(variables['Build.Reason'], 'PullRequest') + timeoutInMinutes: "35" displayName: 'Push base image' #- task: Docker@1 diff --git a/.azure/gpu-tests.yml b/.azure/gpu-tests.yml index 387e47fc9c..22ba01eddc 100644 --- a/.azure/gpu-tests.yml +++ b/.azure/gpu-tests.yml @@ -17,17 +17,17 @@ jobs: matrix: # CUDA 12.1 'ubuntu22.04 | cuda 12.1 | python 3.10 | torch 2.2 | regular': - docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-py3.10-pt_2.2.1' + docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-cudnn-fe1.1.0-py3.10-pt_2.2.1' CUDA_VERSION_MM: '121' 'ubuntu22.04 | cuda 12.1 | python 3.10 | torch 2.2 | distributed': - docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-py3.10-pt_2.2.1' + docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-cudnn-fe1.1.0-py3.10-pt_2.2.1' CUDA_VERSION_MM: '121' testing: 'distributed' 'ubuntu22.04 | cuda 12.1 | python 3.10 | torch-nightly | regular': - docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-py3.10-pt_main' + docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-cudnn-fe1.2.0-py3.10-pt_main' CUDA_VERSION_MM: '121' 'ubuntu22.04 | cuda 12.1 | python 3.10 | torch-nightly | distributed': - docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-py3.10-pt_main' + docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-cudnn-fe1.2.0-py3.10-pt_main' CUDA_VERSION_MM: '121' testing: 'distributed' # how long to run the job before automatically cancelling @@ -111,7 +111,7 @@ jobs: condition: eq(variables['testing'], 'distributed') displayName: 'Testing: distributed' - # todo for Mike as he promised some time ago already... or shall it ne another workflow so keep time low? + # todo (mruberry): decide whether this should be here or in another workflow #- bash: | # python benchmarks/ops_benchmark.py nanogpt-gelu # python benchmarks/nvfuser_benchmarks.py nanogpt-mlp -x thunder diff --git a/.github/workflows/ci-checks.yml b/.github/workflows/ci-checks.yml index 491d275173..ae28b10c53 100644 --- a/.github/workflows/ci-checks.yml +++ b/.github/workflows/ci-checks.yml @@ -16,14 +16,14 @@ jobs: # actions-ref: main check-schema: - uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.10.1 + uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.0 with: azure-dir: ".azure" check-package: - uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.10.1 + uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.11.0 with: - actions-ref: v0.10.1 + actions-ref: v0.11.0 import-name: "thunder" artifact-name: dist-packages-${{ github.sha }} testing-matrix: | diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 34751f2795..cbd8fb2aa6 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -79,7 +79,8 @@ jobs: - name: Install package & dependencies run: | pip --version - pip install -e '.[test]' -U \ + pip install -e . -U \ + -r requirements/test.txt \ --find-links=${TORCH_URL} ${PIP_EXTRA_FLAG} pip list shell: bash diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index a9c8d99b6f..d37e381b40 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -15,7 +15,7 @@ defaults: jobs: build-docs: - uses: Lightning-AI/utilities/.github/workflows/check-docs.yml@v0.10.1 + uses: Lightning-AI/utilities/.github/workflows/check-docs.yml@v0.11.0 with: python-version: "3.10" requirements-file: "requirements/docs.txt" diff --git a/.github/workflows/release-pypi.yml b/.github/workflows/release-pypi.yml index e97fedf8e0..078f9e6066 100644 --- a/.github/workflows/release-pypi.yml +++ b/.github/workflows/release-pypi.yml @@ -27,7 +27,7 @@ jobs: # We do this, since failures on test.pypi aren't that bad - name: Publish to Test PyPI if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' - uses: pypa/gh-action-pypi-publish@v1.8.12 + uses: pypa/gh-action-pypi-publish@v1.8.14 with: user: __token__ password: ${{ secrets.test_pypi_password }} @@ -35,7 +35,7 @@ jobs: - name: Publish distribution 📦 to PyPI if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' - uses: pypa/gh-action-pypi-publish@v1.8.12 + uses: pypa/gh-action-pypi-publish@v1.8.14 with: user: __token__ password: ${{ secrets.pypi_password }} diff --git a/README.md b/README.md index 9593c3aea8..7d60063864 100644 --- a/README.md +++ b/README.md @@ -1,31 +1,94 @@ +
+Thunder +
+
+ +**Make PyTorch models Lightning fast.** + +______________________________________________________________________ + +

+ Lightning.ai • + Performance • + Get started • + Install • + Examples • + Features • + Documentation • +

+ +[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/lightning-thunder/blob/main/LICENSE) +[![CI testing](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml) +[![General checks](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml) +[![Documentation Status](https://readthedocs.org/projects/lightning-thunder/badge/?version=latest)](https://lightning-thunder.readthedocs.io/en/latest/?badge=latest) +[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Lightning-AI/lightning-thunder/main.svg)](https://results.pre-commit.ci/latest/github/Lightning-AI/lightning-thunder/main) + +
+ # Welcome to ⚡ Lightning Thunder -Lightning Thunder is a deep learning compiler for PyTorch. It makes PyTorch programs faster both on single accelerators or in distributed settings. +**Thunder makes PyTorch models Lightning fast.** + +Thunder is a source-to-source compiler for PyTorch. It makes PyTorch programs faster by combining and using different hardware executors at once (ie: nvFuser, torch.compile, cuDNN, and TransformerEngine FP8). + +Works on single accelerators and in multi-GPU settings. +Thunder aims to be usable, understandable, and extensible. + +## Performance + +Thunder can achieve significant speedups over standard PyTorch eager code, through the compounding effects of optimizations and the use of best-in-class executors. Here is an example of the pretraining throughput for Llama 2 7B as implemented in [LitGPT](https://github.com/Lightning-AI/litgpt). + +
+Thunder +
+ +Thunder achieves a 40% speedup in training throughput compared to eager code on H100 using a combination of executors including nvFuser, torch.compile, cuDNN, and TransformerEngine FP8. + +Thunder supports distributed strategies like DDP and FSDP (ZeRO2 and ZeRO3). Here is the normalized throughput measured for Llama 2 7B (this time without FP8 mixed precision, support for FSDP is underway). -The main goal for Lightning Thunder is to allow optimizing user programs in the most extensible and expressive way possible. +
+Thunder +
-**NOTE: Lightning Thunder is alpha and not ready for production runs.** Feel free to get involved, expect a few bumps along the way. +**NOTE: Lightning Thunder is alpha.** Feel free to get involved, expect a few bumps along the way. + +## Get started + +Try Thunder without installing by using our [Zero to Thunder Tutorial Studio](https://lightning.ai/lightning-ai/studios/zero-to-thunder-tutorial). ## Install Thunder -Install the nvFuser nightly, which will also install the matching PyTorch nightly: +Install [nvFuser](https://github.com/NVIDIA/Fuser) nightly, and Thunder together ```bash -pip install --pre "nvfuser-cu121[torch]" --extra-index-url https://pypi.nvidia.com +# install nvFuser which installs the matching nightly PyTorch +pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.com + +# install thunder +pip install lightning-thunder ``` -Install Thunder: +
+ Advanced install options + + +### Install from main ```bash pip install git+https://github.com/Lightning-AI/lightning-thunder.git ``` -or install from the local repo: +### Install to tinker and contribute + +Install this way to tinker with the internals and contribute: ```bash -pip install . +pip install -e . ``` +
+ + ## Hello World Here is a simple example of how Thunder lets you compile and run PyTorch code: @@ -56,11 +119,11 @@ print(result) The compiled function `jfoo` takes and returns PyTorch tensors, just like the original function, so modules and functions compiled by Thunder can be used as part of larger PyTorch programs. -## Running training +## Train models -Thunder is in its early stages, it should not be used for production runs yet. +Thunder is in its early stages and should not be used for production runs yet. -However, it can already deliver outstanding performance on models supported by [LitGPT](https://github.com/Lightning-AI/lit-gpt), such as Mistral, Llama2, Gemma, Falcon, and derivatives. +However, it can already deliver outstanding performance on LLM model supported by [LitGPT](https://github.com/Lightning-AI/lit-gpt), such as Mistral, Llama 2, Gemma, Falcon, and others. Run training loop for Llama, single-GPU: @@ -76,25 +139,25 @@ python examples/lit-gpt/train_fsdp.py See [README.md](examples/lit-gpt/README.md) for details on running LitGPT with Thunder. -## What's in the box +## Features -Given a program, Thunder can generate an optimized program that: +Given a Python callable or PyTorch module, Thunder can generate an optimized program that: -- computes its forward and backward passes -- coalesces operations into efficient fusion regions -- dispatches computations to optimized kernels -- distributes computations optimally across machines +- Computes its forward and backward passes +- Coalesces operations into efficient fusion regions +- Dispatches computations to optimized kernels +- Distributes computations optimally across machines To do so, Thunder ships with: -- a JIT for acquiring Python programs targeting PyTorch and custom operations -- a multi-level IR to represent them as a trace of a reduced op-set -- an extensible set of transformations on the trace, such as `grad`, fusions, distributed (like `ddp`, `fsdp`), functional (like `vmap`, `vjp`, `jvp`) -- a way to dispatch operations to an extensible collection of executors +- A JIT for acquiring Python programs targeting PyTorch and custom operations +- A multi-level IR to represent operations as a trace of a reduced op-set +- An extensible set of transformations on the trace, such as `grad`, fusions, distributed (like `ddp`, `fsdp`), functional (like `vmap`, `vjp`, `jvp`) +- A way to dispatch operations to an extensible collection of executors Thunder is written entirely in Python. Even its trace is represented as valid Python at all stages of transformation. This allows unprecedented levels of introspection and extensibility. -Thunder doesn't generate device code. It acquires and transforms user programs so that it's possible to optimally select or generate device code using fast executors like: +Thunder doesn't generate code for accelerators directly. It acquires and transforms user programs so that it's possible to optimally select or generate device code using fast executors like: - [torch.compile](https://pytorch.org/get-started/pytorch-2.0/) - [nvFuser](https://github.com/NVIDIA/Fuser) @@ -106,7 +169,7 @@ Thunder doesn't generate device code. It acquires and transforms user programs s Modules and functions compiled with Thunder fully interoperate with vanilla PyTorch and support PyTorch's autograd. Also, Thunder works alongside torch.compile to leverage its state-of-the-art optimizations. -## Build the documentation +## Documentation Docs are currently not hosted publicly. However you can build them locally really quickly: @@ -141,9 +204,4 @@ Thunder is very thoroughly tested, so expect this to take a while. ## License Lightning Thunder is released under the [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) license. -See LICENSE file for details. - -[![CI testing](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml) -[![General checks](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml) -[![Documentation Status](https://readthedocs.org/projects/lightning-thunder/badge/?version=latest)](https://lightning-thunder.readthedocs.io/en/latest/?badge=latest) -[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Lightning-AI/lightning-thunder/main.svg?badge_token=mqheL1-cTn-280Vx4cJUdg)](https://results.pre-commit.ci/latest/github/Lightning-AI/lightning-thunder/main?badge_token=mqheL1-cTn-280Vx4cJUdg) +See the [LICENSE](LICENSE) file for details. diff --git a/dockers/ubuntu-cuda/Dockerfile b/dockers/ubuntu-cuda/Dockerfile index 777d723697..e815d827f6 100644 --- a/dockers/ubuntu-cuda/Dockerfile +++ b/dockers/ubuntu-cuda/Dockerfile @@ -74,7 +74,7 @@ RUN \ RUN \ echo "CUDA_VERSION=$CUDA_VERSION ; CUDNN_VERSION=$CUDNN_VERSION " && \ CUDA_VERSION_MM=${CUDA_VERSION%.*} && \ - # there is missing cudnn for 12.1 so use 12.2 instead + # There are some test failures from cuDNN 12.1, so 'upgrade' requests for 12.1 to 12.2. CUDA_VERSION_MM="${CUDA_VERSION_MM/12.1/12.2}" && \ CUDNN_BASE_VER=${CUDNN_VERSION%%.*} && \ CUDNN_PACKAGE_VER="${CUDNN_VERSION}+cuda${CUDA_VERSION_MM}" && \ diff --git a/docs/source/_static/images/lightning_thunder_lightmode_nobyline.png b/docs/source/_static/images/lightning_thunder_lightmode_nobyline.png new file mode 100644 index 0000000000..831a23e233 Binary files /dev/null and b/docs/source/_static/images/lightning_thunder_lightmode_nobyline.png differ diff --git a/docs/source/_static/images/normalized_training_throughput_zero2.png b/docs/source/_static/images/normalized_training_throughput_zero2.png new file mode 100644 index 0000000000..be6e5888c3 Binary files /dev/null and b/docs/source/_static/images/normalized_training_throughput_zero2.png differ diff --git a/docs/source/_static/images/training_throughput_single.png b/docs/source/_static/images/training_throughput_single.png new file mode 100644 index 0000000000..6c0a7029a4 Binary files /dev/null and b/docs/source/_static/images/training_throughput_single.png differ diff --git a/docs/source/advanced/inside_thunder.rst b/docs/source/advanced/inside_thunder.rst index 36013d66ae..2ae68509d8 100644 --- a/docs/source/advanced/inside_thunder.rst +++ b/docs/source/advanced/inside_thunder.rst @@ -8,9 +8,9 @@ Bytecode interpretation Thunder's interpreter works by: -1. disassembling the PyTorch module or function into CPython bytecode -2. interpreting the bytecode using an extended Python interpreter -3. generating a sequential trace of operations on tensors and numbers +1. Disassembling the PyTorch module or function into CPython bytecode +2. Interpreting the bytecode using an extended Python interpreter +3. Generating a sequential trace of operations on tensors and numbers Representing Operations ======================= diff --git a/docs/source/basic/mlp_mnist.rst b/docs/source/basic/mlp_mnist.rst index f4335c91f4..e7e81912c3 100644 --- a/docs/source/basic/mlp_mnist.rst +++ b/docs/source/basic/mlp_mnist.rst @@ -90,7 +90,8 @@ Here's the code:: # The training model has both "forward" and "backward" traces, corresponding # to its forward and backward computations. # The evaluation model has only one set of traces. - fwd_traces, bwd_traces = thunder.last_traces(jitted_train_model) + fwd_traces = thunder.last_traces(jitted_train_model) + bwd_traces = thunder.last_backward_traces(jitted_train_model) eval_traces = thunder.last_traces(jitted_eval_model) print("This is the trace that thunder executed for training's forward computation:") diff --git a/docs/source/basic/overview.rst b/docs/source/basic/overview.rst index cb272c0b9a..c456635e62 100644 --- a/docs/source/basic/overview.rst +++ b/docs/source/basic/overview.rst @@ -3,7 +3,7 @@ Thunder Overview This section introduces Thunder's core concepts and architecture. For more details, see :doc:`Inside thunder <../advanced/inside_thunder>`. -Thunder is a deep learning compiler for PyTorch, which means it translates calls to PyTorch modules into a format that is easy to transform and that executors can consume to produce fast executables. This translation must be “valid” - it must produce a simple representation focusing on tensor operations. The format we've chosen, like other deep learning compilers, is a sequence of operations called a program *trace*. +Thunder is a deep learning compiler for PyTorch, which means it translates calls to PyTorch modules into a format that is easy to transform and that executors can consume to produce fast executables. This translation must produce a simple representation focusing on tensor operations. The format we've chosen, like other deep learning compilers, is a sequence of operations called a program *trace*. This translation begins with:: @@ -13,7 +13,7 @@ or:: jitted_fn = thunder.jit(my_function) -When given a module, the call to ``thunder.jit()`` returns a Thunder-optimized module that shares parameters with the original module (as demonstrated in the :doc:`Train a MLP on MNIST ` example), and when given a function it returns a jitted function. +When given a module, the call to ``thunder.jit()`` returns a Thunder-optimized module that shares parameters with the original module (as demonstrated in the :doc:`Train a MLP on MNIST ` example), and when given a function it returns a function that when called will jit compile a path through the original function given information about the inputs. When the jitted module or function is called:: @@ -23,22 +23,23 @@ or:: jitted_fn(*args, **kwargs) -Thunder begins reviewing the module's or function's Python bytecode and the input. It may be surprising that Thunder considers the inputs at all, but this is actually required to produce a trace. Different inputs can produce different traces, since the operations called may different based on the properties of the input. -The trace is generated by running the bytecode through an extensible Python interpreter implemented in Python itself, that can be extended to perform instructions in a different way compared to what standard CPython does. As such, it can be instrumented to construct a trace of operations performed on tensors or numbers, and keep track of the provenance of all objects being part of the program. +As suggested above, Thunder begins reviewing the module's or function's Python bytecode and the input. It may be surprising that Thunder considers the inputs at all, but since control flow (and therefore the operations captured) may vary depending on the input, this is actually required to produce a trace. These traces are cached, so that if inputs of the same type, shape, etc are used again, the trace can be reused. -If replacing CPython with Python itself sounds problematic from a performance perspective, keep in mind that the initial interpretation of a deep learning program is typically amortized during the subsequent interpretations, due to the iterative nature of deep learning programs. In other words, if the meta data of inputs (like tensor shape) doesn't change and control-flow conditions are unchanged, then there's no point in constructing a new trace, and we can rely on smart caching to just execute a trace right away. +Traces are generated by running the bytecode through a custom Python interpreter, which is itself implemented in Python. This interpreter has been extended to perform instructions in a different way compared to what standard CPython does. In particular, it constructs a trace of operations performed on tensors or numbers, and keeps track of the provenance of all objects in the program, whether they originated from inside the interpreter or outside. -Traces don't typically deal with PyTorch tensors, but with *proxies* that only have metadata like shape, device, dtype, and whether the tensor requires grad or not. As such, during interpretation for trace generation, the execution of the program doesn't perform any computation on accelerators, but it records the operators along one path of the traceable function into the trace. +Much like other machine learning frameworks, Traces don't typically deal directly with PyTorch tensors, but with *proxies* that only have metadata like shape, device, dtype, and whether the tensor requires grad or not. As such, during interpretation for trace generation, the execution of the program doesn't perform any computation on accelerators. Instead, it records the operators along one path of the traceable function. -Traces can be transformed (like for backward) and optimized (like by replacing calls to PyTorch operations with calls to faster executors), and the final result of this process is an *execution trace*. Thunder executes the original call by converting the execution trace into a Python function and calling that function with the actual inputs. For details about this optimization process see the :doc:`thunder step by step ` section. +If replacing CPython with an interpreter written in Python sounds problematic from a performance perspective, you would be largely correct. We haven't yet put any time into optimizing it, and we think it consumes roughly 400x as much CPU time as CPython. However, the function only needs to be jitted once per equivalence class of inputs, and CPU is not a bottleneck in most machine learning pipelines. As long as the metadata of the inputs (such as a tensor's shape) and control flow conditions are not changed, we can rely on smart caching to immediately execute an optimized trace. The end result is a faster total execution time. + +Traces can be transformed (like for ``backward()``) and optimized (like by replacing calls to eager PyTorch operations with calls to faster executors), and the final result of this process is an *execution trace*. Thunder executes the original call by converting the execution trace into a Python function and calling that function with the actual inputs. For details about this optimization process, see the :doc:`thunder step by step ` section. To recap, the complete translation process is: -- For PyTorch modules, a Thunder-optimized module is created from the original module -- For PyTorch functions, compilation produces a compiled function -- When the module or function is called, the trace is generated, swapping some inputs with “proxies” -- The trace is transformed and optimized to produce an execution trace -- The execution trace is converted into a Python function and called +- For PyTorch modules, a Thunder-optimized module is created from the original module. +- For PyTorch functions, compilation produces a compiled function. +- When the module or function is called, the trace is generated, swapping some inputs with “proxies”. +- The trace is transformed and optimized to produce an execution trace. +- The execution trace is converted into a Python function and called. -As mentioned above, this translation process is often slow - it takes tens of seconds for nanoGPT's (https://github.com/karpathy/nanoGPT) largest configuration - so Thunder's performance model expects relatively few of these translations and then a lot of uses of the result. This corresponds with many training and inference patterns, where the same program is executed many times. +As mentioned, this translation process is often slow - it takes tens of seconds for nanoGPT's (https://github.com/karpathy/nanoGPT) largest configuration - so Thunder's performance model expects relatively few of these translations and then a lot of uses of the result. This corresponds with many training and inference patterns, where the same program is executed many times. diff --git a/docs/source/basic/sharp_edges.rst b/docs/source/basic/sharp_edges.rst index bf2b0abf16..d62590316b 100644 --- a/docs/source/basic/sharp_edges.rst +++ b/docs/source/basic/sharp_edges.rst @@ -10,12 +10,6 @@ Inplace operations Inplace PyTorch operations like `t.add_(1.0)` are not supported in Thunder yet. Support for inplace operations is coming soon. -Complex control flow --------------------- - -Control flow is supported in Thunder, but certain constructs might still be unsupported. - -In particular, attributes need to be resolved at tracing time for control flow to work. Data-dependent control flow, that is, when a condition depends on the value of tensors rather than its meta-data like shape or type, is currently not supported. Tensor subclasses ----------------- @@ -24,13 +18,14 @@ Thunder currently supports Python data types and PyTorch tensors as inputs of fu Subclasses of these types, e.g. lazy tensors, nested tensors, or sparse tensors are not supported today. + Tracing Python builtins, standard library operations and functions that call other languages -------------------------------------------------------------------------------------------- Calling a Python builtin, standard library operation, or a function that calls into another language is safe to trace, so long as the following rules are observed: -1. The function must not have side effects. For example, calling ``print()`` will execute the ``print()`` function while tracing, but since it's not a Thunder operation it will not appear in a trace, and so future cached executions will not execute the ``print()`` statement. -2. The function must not manipulate tensor metadata or data. Since the operation won't appear in a trace, these manipulations won't be repeated by Thunder, and may even cause a crash while tracing. +1. The function should not have side effects. For example, calling ``print()`` will execute the ``print()`` function while tracing, but since it's not a Thunder operation it will not appear in a trace, and so future cached executions will not execute the ``print()`` statement. +2. The function must not manipulate tensor data or metadata. Since the operation won't appear in a trace, these manipulations won't be repeated by Thunder, and may even cause a crash while tracing. To implement such operations, see :doc:`Adding Custom Operators <../notebooks/adding_custom_operator>` 3. The function must not produce different results across invocations. Again, since the operation won't appear in traces, Thunder cannot replicate an operation that produces different results when it's invoked, like ``random.random()`` will. .. diff --git a/docs/source/conf.py b/docs/source/conf.py index dcf2414b54..052fa2437c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -49,7 +49,10 @@ github_user = "Lightning-AI" github_repo = project -linkcheck_ignore = [rf"https://github.com/Lightning-AI/lightning-thunder(/.*|\.git)"] +linkcheck_ignore = [ + rf"https://github.com/Lightning-AI/lightning-thunder(/.*|\.git)", + rf"https://github.com/Lightning-AI/.*/blob/.*#.*", # github anchors are tricky +] # -- Project documents ------------------------------------------------------- diff --git a/docs/source/fundamentals/installation.rst b/docs/source/fundamentals/installation.rst index 7ee759ea82..8d41a24047 100644 --- a/docs/source/fundamentals/installation.rst +++ b/docs/source/fundamentals/installation.rst @@ -56,11 +56,11 @@ Thunder can easily integrate OpenAI Triton kernels. You can install Triton using Install Thunder =============== -You can now install Thunder +You can now install Thunder:: pip install git+https://github.com/Lightning-AI/lightning-thunder.git -Alternatively you can clone the Thunder repository and install locally +Alternatively you can clone the Thunder repository and install locally:: git clone https://github.com/Lightning-AI/lightning-thunder.git cd lightning-thunder diff --git a/docs/source/index.rst b/docs/source/index.rst index 804d5393b3..3ce2ca87c0 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -99,7 +99,7 @@ The compiled function ``jitted_foo`` takes and returns PyTorch tensors, just lik Additional executors Distributed Data Parallel What's next - FSDP Tutorial + FSDP Under the Hood Tutorial .. toctree:: :maxdepth: 1 @@ -110,7 +110,6 @@ The compiled function ``jitted_foo`` takes and returns PyTorch tensors, just lik Extending thunder notebooks/adding_custom_operator notebooks/adding_custom_operator_backward - notebooks/adding_operator_executor .. toctree:: :maxdepth: 1 @@ -118,7 +117,6 @@ The compiled function ``jitted_foo`` takes and returns PyTorch tensors, just lik :caption: Experimental dev tutorials notebooks/dev_tutorials/extend - notebooks/dev_tutorials/patterns .. TODO RC1: update notebooks diff --git a/docs/source/intermediate/additional_executors.rst b/docs/source/intermediate/additional_executors.rst index 76bb270bf0..911d7f16e9 100644 --- a/docs/source/intermediate/additional_executors.rst +++ b/docs/source/intermediate/additional_executors.rst @@ -10,11 +10,9 @@ Triton CrossEntropy Executor The Triton CrossEntropy executor can execute ``torch.cross_entropy()`` using an optimized kernel written in OpenAI Triton (https://github.com/openai/triton). It can be used like in the following example:: + import torch import thunder - from thunder.executors import nvfuserex, torchex - from thunder.executors.triton_crossentropy import deregister_triton_entropyex, register_triton_entropyex - - register_triton_entropyex(add_to_default_executors=False) + from thunder.executors.triton_crossentropy import triton_ex as triton_cross_entropy_ex def xentropy(logits, labels, weight, reduction, ignore_index): return thunder.torch.cross_entropy( @@ -23,7 +21,7 @@ The Triton CrossEntropy executor can execute ``torch.cross_entropy()`` using an jitted_xentropy = thunder.jit( xentropy, - executors_list=['triton_crossentropy', nvfuserex, torchex] + executors=[triton_cross_entropy_ex,] ) device = 'cuda' @@ -41,43 +39,42 @@ The Triton CrossEntropy executor can execute ``torch.cross_entropy()`` using an This prints:: - # Constructed by Delete Last Used + # Constructed by Delete Last Used (took 0 milliseconds) import torch + from thunder.executors.torchex import no_autocast + @torch.no_grad() - def xentropy(logits, labels, weight, reduction, ignore_index): + @no_autocast() + def computation(logits, labels, weight): # logits: "cuda:0 f32[2048, 50257]" # labels: "cuda:0 i64[2048]" # weight: "cuda:0 f32[50257]" - # "sum" - # ignore_index: "int 10106" - t22 = triton_cross_entropy(logits, labels, weight, None, ignore_index, None, "sum", 0.0) # t22: "cuda:0 f32[]" - del [logits, labels, weight, ignore_index] - return t22 + t23 = triton_crossentropy(logits, labels, weight, None, 45279, None, 'sum', 0.0) # t23: "cuda:0 f32[]" + del logits, labels, weight + return t23 -As shown in the above trace, ``triton_cross_entropy()`` is the one running the operation. +As shown in the above trace, ``triton_crossentropy()`` is the one running the operation. Apex CrossEntropy Executor ========================== The Apex CrossEntropy executor can execute ``torch.cross_entropy()`` through an optimized kernel, like this:: + import torch import thunder - from thunder.executors import nvfuserex, torchex - from thunder.executors.apex_entropyex import deregister_apex_entropyex, register_apex_entropyex - - register_apex_entropyex(add_to_default_executors=False) + from thunder.executors.apex_entropyex import apex_ex def xentropy(logits, labels): return thunder.torch.cross_entropy( logits, labels, reduction='mean', ignore_index=-1 ) - jitted_xentropy = thunder.jit(xentropy, executors_list=['apex_xentropy', nvfuserex, torchex]) + jitted_xentropy = thunder.jit(xentropy, executors=[apex_ex,]) device = 'cuda' dtype = torch.float32 - logits = torch.randn([2048, 50257], device=device, dtype=thunder.torch.to_torch_dtype(dtype)) + logits = torch.randn([2048, 50257], device=device, dtype=dtype) labels = torch.randint(0, 50257, [2048], device=device) jitted_xentropy(logits, labels) @@ -86,14 +83,17 @@ The Apex CrossEntropy executor can execute ``torch.cross_entropy()`` through an This prints:: - # Constructed by Delete Last Used + # Constructed by Delete Last Used (took 0 milliseconds) import torch + from thunder.executors.torchex import no_autocast + @torch.no_grad() - def xentropy(logits, labels): + @no_autocast() + def computation(logits, labels): # logits: "cuda:0 f32[2048, 50257]" # labels: "cuda:0 i64[2048]" - t18 = apex_cross_entropy(logits, labels, None, None, -1, None, "mean", 0.0) # t18: "cuda:0 f32[]" - del [logits, labels] + (t18, _) = apex_cross_entropy(logits, labels, 'mean', 0.0) + del logits, labels return t18 showing that Apex is running the operation. diff --git a/docs/source/reference/common/index.rst b/docs/source/reference/common/index.rst index 0011c1b21f..20e0144d42 100644 --- a/docs/source/reference/common/index.rst +++ b/docs/source/reference/common/index.rst @@ -9,4 +9,3 @@ Common functions and classes for Thunder. :toctree: generated/ CACHE_OPTIONS - preprocess diff --git a/docs/source/reference/thunder.rst b/docs/source/reference/thunder.rst index ced14ed4e8..706f9d6f08 100644 --- a/docs/source/reference/thunder.rst +++ b/docs/source/reference/thunder.rst @@ -25,6 +25,7 @@ Querying information on compiled functions and modules compile_data compile_stats last_traces + last_backward_traces last_prologue_traces cache_option cache_hits diff --git a/examples/lit-gpt/_ddp_thunder.py b/examples/lit-gpt/_ddp_thunder.py index 8d53a567a4..1bd07619df 100644 --- a/examples/lit-gpt/_ddp_thunder.py +++ b/examples/lit-gpt/_ddp_thunder.py @@ -199,8 +199,8 @@ def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: f" Got: {module.__class__.__name__}." ) - # see https://github.com/Lightning-AI/lightning-thunder/issues/2085 - # for why we cannot just return `module.no_sync()` + # issue "Limitations of the current DDP no_sync implementation" has + # details on why we cannot just return `module.no_sync()` from thunder.distributed import skip_data_parallel_grad_sync previous, self._enabled = self._enabled, enabled diff --git a/examples/lit-gpt/_fsdp_thunder.py b/examples/lit-gpt/_fsdp_thunder.py index 77ad8cfba0..133c40b1f2 100644 --- a/examples/lit-gpt/_fsdp_thunder.py +++ b/examples/lit-gpt/_fsdp_thunder.py @@ -414,8 +414,7 @@ def _get_state_dict( def _unwrap_tom(obj: object) -> object: # TODO: this unwrap won't be required when Fabric's `_unwrap_objects` supports Thunder from thunder import ThunderModule - from thunder.common import ThunderOptimizedModule - if isinstance(obj, (ThunderOptimizedModule, ThunderModule)): + if isinstance(obj, ThunderModule): return obj._model return obj diff --git a/examples/lit-gpt/test_parametrized.py b/examples/lit-gpt/test_parametrized.py index 20ddaa9278..5e658b6447 100644 --- a/examples/lit-gpt/test_parametrized.py +++ b/examples/lit-gpt/test_parametrized.py @@ -7,20 +7,18 @@ MID_BENCHMARK_OUT - use this env variable to control whether you want to see the combined results between each test. BENCHMARK_OUT_FORMAT - use this env variable to control the format in which the results are presented. - Uses 'xlsx' by default. More format support to come soon. + Uses 'xlsx' by default. Supported: 'none', 'print', 'xlsx'. ''' import torch from absl.testing import parameterized from absl.testing import absltest +from collections import defaultdict import os -import pickle import subprocess -import warnings import json import pandas as pd from datetime import datetime -import threading class Runner: ''' @@ -51,6 +49,9 @@ def add_to_dataframe(self): self.dataframe_data.append(self.perf_metrics_dict) def complete_dataframe(self, is_teardown): + if not self.dataframe_data: + # The benchmark probably failed + return #Called when tearing down the parametrized test #This generates a summarized dataframe for each perf metric and saves as a xlsx file df = pd.DataFrame(self.dataframe_data) @@ -62,14 +63,14 @@ def complete_dataframe(self, is_teardown): self.tokens_per_sec_per_gpu_df = df.pivot_table(index=index_list, columns='compiler', values='tokens_per_sec_per_gpu', aggfunc='first').reset_index() self.memory_used_GB_df = df.pivot_table(index=index_list, columns='compiler', values='memory_used_GB', aggfunc='first').reset_index() - if self.output_format not in ('none', 'print'): + if self.output_format == "xlsx": output_ext = {'xlsx': '.xlsx', }[self.output_format] if not is_teardown: - filename = '/scratch/lightning-thunder/examples/lit-gpt/mid_output_parameterized_results' + str(output_ext) + filename = 'examples/lit-gpt/mid_output_parameterized_results' + str(output_ext) else: current_time = datetime.now().strftime('%Y-%m-%d_%H-%M') filename = f"{current_time}_litgpt_benchmark" + str(output_ext) - filename = '/scratch/lightning-thunder/examples/lit-gpt/' + str(filename) + filename = 'examples/lit-gpt/' + str(filename) with pd.ExcelWriter(filename, engine='xlsxwriter') as writer: self.iter_time_df.to_excel(writer, sheet_name='Average Iter Time (ms)') @@ -87,41 +88,40 @@ def complete_dataframe(self, is_teardown): print(self.memory_used_GB_df) def run_benchmark(self, kwargs): - # benchmark_file = '/scratch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py' command_list = [] for key, val in kwargs.items(): command_list.append("--" + str(key) + "=" + str(val)) if kwargs['distributed_mode'] != 'none': - subprocess_cmd = ["torchrun", "--nproc_per_node=8", "--nnodes=1", "{}".format(self.benchmark_file), "--return_metrics_as_json=True", "--json_path={}".format(self.json_file_path)] + nproc_per_node = torch.cuda.device_count() + subprocess_cmd = ["torchrun", f"--nproc_per_node={nproc_per_node}", "--nnodes=1", "{}".format(self.benchmark_file), "--return_metrics_as_json=True", "--json_path={}".format(self.json_file_path)] subprocess_cmd.extend(command_list) else: - subprocess_cmd = ["python", "{}".format(benchmark_file), "--return_metrics_as_json=True", "--json_path={}".format(self.json_file_path)] + subprocess_cmd = ["python", "{}".format(self.benchmark_file), "--return_metrics_as_json=True", "--json_path={}".format(self.json_file_path)] subprocess_cmd.extend(command_list) print(f'Running {" ".join(subprocess_cmd)!r}') proc_output = subprocess.run(subprocess_cmd, capture_output=True, text=True) - with open(self.json_file_path, 'r') as file: - self.perf_metrics_dict = json.load(file) - os.remove(self.json_file_path) #cleanup after test finishes - - if self.perf_metrics_dict['average_iter_time'] is None: - if 'CUDA out of memory' in proc_output.stdout: - self.perf_metrics_dict['average_iter_time'] = 'OOM' - self.perf_metrics_dict['model_flops'] = 'OOM' - self.perf_metrics_dict['model_flop_per_sec'] = 'OOM' - self.perf_metrics_dict['tokens_per_sec'] = 'OOM' - self.perf_metrics_dict['tokens_per_sec_per_gpu'] = 'OOM' - self.perf_metrics_dict['memory_used_GB'] = 'OOM' + self.perf_metrics_dict = {} + if os.path.exists(self.json_file_path): + with open(self.json_file_path, 'r') as file: + self.perf_metrics_dict = json.load(file) + # Cleanup after the benchmark finishes. It might have failed before creating this + os.remove(self.json_file_path) + + if proc_output.returncode: + if 'CUDA out of memory' in proc_output.stdout or "CUDA error: out of memory" in proc_output.stderr: + defaultdict_oom = defaultdict(lambda: "OOM") + defaultdict_oom.update(self.perf_metrics_dict) + self.perf_metrics_dict = defaultdict_oom pass_str = "TestCase did not finish reporting metrics due to CUDA out of memory error. Reporting OOM and triggering test success." return True, pass_str - else: - fail_str = "Testcase did not finish reporting metrics due to an unknown error. Triggering test failure." - return False, fail_str - else: - return True, "Test passed successfully." - # print(proc_output.stdout) - # print(proc_output.stderr) + print(proc_output.stdout) + print(proc_output.stderr) + fail_str = "TestCase did not finish reporting metrics due to an unknown error. Triggering test failure." + return False, fail_str + return True, "Test passed successfully." + class Test(parameterized.TestCase): @@ -152,12 +152,12 @@ def tearDownClass(cls): # dict(distributed_mode = "none", shard_mode = "none")), # (dict(model_name = 'Llama-2-7b-hf', micro_batch_size=1), # dict(model_name = 'Llama-2-7b-hf', micro_batch_size=2), - # dict(model_name = 'Llama-2-13b{}-hf', micro_batch_size=1), - # dict(model_name = 'Llama-2-13b{}-hf', micro_batch_size=2), + # dict(model_name = 'Llama-2-13b-hf', micro_batch_size=1), + # dict(model_name = 'Llama-2-13b-hf', micro_batch_size=2), # dict(model_name = 'stablecode-completion-alpha-3b', micro_batch_size=1), # dict(model_name = 'stablecode-completion-alpha-3b', micro_batch_size=2), - # dict(model_name = 'Mistral-7B-{}v0.1', micro_batch_size=1), - # dict(model_name = 'Mistral-7B-{}v0.1', micro_batch_size=2), + # dict(model_name = 'Mistral-7B-v0.1', micro_batch_size=1), + # dict(model_name = 'Mistral-7B-v0.1', micro_batch_size=2), # dict(model_name = 'open_llama_3b', micro_batch_size=1), # dict(model_name = 'open_llama_3b', micro_batch_size=2), # dict(model_name = 'open_llama_3b', micro_batch_size=4), @@ -178,8 +178,8 @@ def tearDownClass(cls): # dict(model_name = 'pythia-6.9b', micro_batch_size=2), # dict(model_name = 'pythia-12b', micro_batch_size=1), # dict(model_name = 'pythia-12b', micro_batch_size=2), - # dict(model_name = 'falcon-7b{}', micro_batch_size=1), - # dict(model_name = 'falcon-7b{}', micro_batch_size=2)), + # dict(model_name = 'falcon-7b', micro_batch_size=1), + # dict(model_name = 'falcon-7b', micro_batch_size=2)), # compile = ("eager", "inductor", "thunder", "thunder_inductor",) # ) diff --git a/examples/lit-gpt/train.py b/examples/lit-gpt/train.py index bf5b5e1e6e..412711ce5a 100644 --- a/examples/lit-gpt/train.py +++ b/examples/lit-gpt/train.py @@ -15,7 +15,7 @@ def main(compile: str = "eager", dynamic: bool = False) -> None: fabric = L.Fabric(devices=1, precision="bf16-true") - fabric.seed_everything(1337, workers=True) # same seed for every process to init model (FSDP) + fabric.seed_everything(42, workers=True) # same seed for every process to init model (FSDP) config = Config.from_name(model_name) print(f"Loading model with {config.__dict__}") diff --git a/examples/lit-gpt/train_fsdp.py b/examples/lit-gpt/train_fsdp.py index c855d61b92..e896d52ef3 100644 --- a/examples/lit-gpt/train_fsdp.py +++ b/examples/lit-gpt/train_fsdp.py @@ -38,7 +38,7 @@ def main( fabric = L.Fabric(devices=devices, strategy=strategy, precision="bf16-true") fabric.launch() - fabric.seed_everything(1337, workers=True) # same seed for every process to init model (FSDP) + fabric.seed_everything(42, workers=True) # same seed for every process to init model (FSDP) config = Config.from_name(model_name) fabric.print(f"Loading model with {config.__dict__}") diff --git a/examples/llama2.c/README.md b/examples/llama2.c/README.md index 5acb8f0742..4a999da840 100644 --- a/examples/llama2.c/README.md +++ b/examples/llama2.c/README.md @@ -28,9 +28,9 @@ The code is configured to run with Thunder by default. Results with 1 GPU: -- ~339 ms/iter (torch.compile 'inductor') -- ~347 ms/iter (thunder nvfuser) -- ~431 ms/iter (eager) +- ~215 ms/iter (torch.compile 'inductor') +- ~239 ms/iter (thunder nvfuser) +- ~339 ms/iter (eager) CUDAGraphs are not used as the results were worse with them. @@ -46,15 +46,14 @@ nanoGPT doesn't implement KV caching so this is expectedly slow. Please checkout ## Setup ```text -Python version: 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] (64-bit runtime) +Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime) Is debug build: False -CUDA used to build PyTorch: 12.1 -CUDA runtime version: 12.1.105 +CUDA used to build PyTorch: 12.4 +CUDA runtime version: 12.4.99 GPU 0: NVIDIA A100-SXM4-40GB -Nvidia driver version: 525.125.06 +Nvidia driver version: 550.54.14 -pytorch-triton @ https://download.pytorch.org/whl/nightly/pytorch_triton-3.0.0%2B901819d2b6-cp310-cp310-linux_x86_64.whl -torch @ https://download.pytorch.org/whl/nightly/cu121/torch-2.3.0.dev20240130%2Bcu121-cp310-cp310-linux_x86_64.whl -lightning-thunder==8b107c6fe531c94c6705dbf39700863685ba5b65 -nvfuser_cu121==0.1.5.dev20240131 +triton == 3.0.0 +torch == 2.4.0a0+git685ace3 +nvfuser @ 0.2.0+git70101da ``` diff --git a/examples/llama2.c/model.py b/examples/llama2.c/model.py index aaf4aad819..297af9e1f6 100644 --- a/examples/llama2.c/model.py +++ b/examples/llama2.c/model.py @@ -65,7 +65,6 @@ def apply_rotary_emb( xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1) # reshape freqs_cos and freqs_sin for broadcasting - # https://github.com/Lightning-AI/lightning-thunder/issues/1106 a, b = freqs_cos.shape freqs_cos = freqs_cos.view(1, a, 1, b) freqs_sin = freqs_sin.view(1, a, 1, b) @@ -244,7 +243,7 @@ def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) if targets is not None: # if we are given some desired targets also calculate the loss logits = self.output(h) - # https://github.com/Lightning-AI/lightning-thunder/issues/1108 + # see issue "Unexpected KeyError when self attribute is set inside forward" #self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) else: # inference-time mini-optimization: only forward the output on the very last position @@ -258,7 +257,7 @@ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): param_dict = {pn: p for pn, p in self.named_parameters()} # filter out those that do not require grad param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} - # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. + # create optim groups. Any parameter that is 2D will be weight decayed, otherwise no. # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] diff --git a/examples/llama2.c/sample.py b/examples/llama2.c/sample.py index 9184340203..094b6c3f74 100644 --- a/examples/llama2.c/sample.py +++ b/examples/llama2.c/sample.py @@ -20,11 +20,10 @@ temperature = 1.0 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions top_k = 300 # retain only the top_k most likely tokens, clamp others to have 0 probability tokenizer = "" # override the tokenizer model path -seed = 1337 +seed = 42 device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. -# thunder does not support autocast: https://github.com/Lightning-AI/lightning-thunder/issues/491 # dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16' -compile = True # Use lightning.compile to compile the model to be faster +compile = True # Use thunder.jit to compile the model to be faster exec(open('configurator.py').read()) # overrides from command line or config file # ----------------------------------------------------------------------------- @@ -33,7 +32,6 @@ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast -# thunder does not support autocast: https://github.com/Lightning-AI/lightning-thunder/issues/491 # ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] ctx = nullcontext() # torch.amp.autocast(device_type=device_type, dtype=ptdtype) @@ -57,10 +55,10 @@ from thunder.executors.sdpaex import sdpa_ex executors = [sdpa_ex, thunder.nvfuser_executor, thunder.pytorch_executor] - cmodel = thunder.compile(model, disable_torch_autograd_support=True, executors_list=executors) + cmodel = thunder.jit(model, disable_torch_autograd_support=True, executors=executors) # the generate implementation is not compile friendly, so bind the compiled model to the generate implementation generate = partial(Transformer.generate, cmodel) - # workaround for https://github.com/Lightning-AI/lightning-thunder/issues/954 + # workaround for "Foward nn.Module attributes through the ThunderOptimizedModule" cmodel.params = model.params else: generate = model.generate diff --git a/examples/llama2.c/tinystories.py b/examples/llama2.c/tinystories.py index cafc1b164a..5ef5c6a247 100644 --- a/examples/llama2.c/tinystories.py +++ b/examples/llama2.c/tinystories.py @@ -191,7 +191,7 @@ def __iter__(self): # get DDP rank info rank = dist.get_rank() if dist.is_initialized() else 0 # combine the worker_id and worker_rank to create a unique seed for rng - seed = 42 + worker_id + 1337 * rank + seed = 42 + worker_id + 1942 * rank rng = random.Random(seed) print(f"Created a PretokDataset with rng seed {seed}") if self.vocab_source == "llama2": diff --git a/examples/llama2.c/train.py b/examples/llama2.c/train.py index 18290df075..206a4e065d 100644 --- a/examples/llama2.c/train.py +++ b/examples/llama2.c/train.py @@ -70,9 +70,8 @@ warmup_iters = 1000 # how many steps to warm up for # system device = "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks -# thunder does not support autocast: https://github.com/Lightning-AI/lightning-thunder/issues/491 -# dtype = "bfloat16" # float32|bfloat16|float16 -compile = "thunder" # eager|torch|thunder +dtype = "bfloat16" # float32|bfloat16|float16 +compile = "thunder" # thunder|torch|eager # ----------------------------------------------------------------------------- config_keys = [ k @@ -118,14 +117,20 @@ if master_process: os.makedirs(out_dir, exist_ok=True) -torch.manual_seed(1337 + seed_offset) +torch.manual_seed(42 + seed_offset) torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast # note: float16 data type will automatically use a GradScaler -# thunder does not support autocast: https://github.com/Lightning-AI/lightning-thunder/issues/491 -# ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype] -ctx = nullcontext() # torch.amp.autocast(device_type=device_type, dtype=ptdtype) +ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype] +ctx = ( + nullcontext() + if device_type == "cpu" + else torch.amp.autocast(device_type=device_type, dtype=ptdtype) +) +# Disable other than FlashAttention backends for SDPA +torch.backends.cuda.enable_math_sdp(False) +torch.backends.cuda.enable_mem_efficient_sdp(False) # task-specific setup iter_batches = partial( @@ -181,10 +186,11 @@ model.load_state_dict(state_dict) iter_num = checkpoint["iter_num"] best_val_loss = checkpoint["best_val_loss"] + model.to(device) # initialize a GradScaler. If enabled=False scaler is a no-op -scaler = torch.cuda.amp.GradScaler(enabled=(False)) # dtype == "float16")) +scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16")) # optimizer optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type) @@ -192,19 +198,19 @@ optimizer.load_state_dict(checkpoint["optimizer"]) checkpoint = None # free up memory -raw_model = eval_model = train_model = model +raw_model = model # wrap model into DDP container if ddp: if compile == "thunder": from thunder.distributed import ddp - train_model = ddp(train_model) + model = ddp(model) else: # Ignore the `freqs_cis` buffer so that DDP does not broadcast it at # construction time since NCCL does not support `ComplexFloat` - train_model._ddp_params_and_buffers_to_ignore = {"freqs_cis"} - train_model = DDP(train_model, device_ids=[ddp_local_rank]) + model._ddp_params_and_buffers_to_ignore = {"freqs_cis"} + model = DDP(model, device_ids=[ddp_local_rank]) # compile the model if compile == "thunder": @@ -214,31 +220,29 @@ from thunder.executors.sdpaex import sdpa_ex executors = [sdpa_ex, thunder.nvfuser_executor, thunder.pytorch_executor] - eval_model = thunder.compile(eval_model.eval(), disable_torch_autograd_support=True, executors_list=executors) - train_model = thunder.compile(train_model.train(), executors_list=executors) + model = thunder.jit(model, executors=executors) elif compile == "torch": print("compiling the model with torch... (takes a ~minute)") - eval_model = torch.compile(eval_model) - train_model = torch.compile(train_model) + model = torch.compile(model) # helps estimate an arbitrarily accurate loss over either split using many batches @torch.no_grad() def estimate_loss(): out = {} if compile != "thunder": - eval_model.eval() + model.eval() for split in ["train", "val"]: batch_iter = iter_batches(split=split) losses = torch.zeros(eval_iters) # keep on CPU for k in range(eval_iters): X, Y = next(batch_iter) with ctx: - logits = eval_model(X, Y) + logits = model(X, Y) loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=-1) losses[k] = loss.item() out[split] = losses.mean() if compile != "thunder": - train_model.train() + model.train() return out # learning rate decay scheduler (cosine with warmup) @@ -313,11 +317,11 @@ def get_lr(it): if ddp: # in DDP training we only need to sync gradients at the last micro step. # the official way to do this is with model.no_sync() context manager, but - # I really dislike that this bloats the code and forces us to repeat code + # this forces us to repeat code. # looking at the source of that context manager, it just toggles this variable - train_model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1 + model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1 with ctx: - logits = train_model(X, Y) + logits = model(X, Y) loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=-1) loss = loss / gradient_accumulation_steps # immediately async prefetch next batch while model is doing the forward pass on the GPU @@ -327,7 +331,7 @@ def get_lr(it): # clip the gradient if grad_clip != 0.0: scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(train_model.parameters(), grad_clip) + torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) # step the optimizer and scaler if training in fp16 scaler.step(optimizer) scaler.update() diff --git a/notebooks/.ignore.ci b/notebooks/.ignore.ci index 10cac1cc86..8655f707bb 100644 --- a/notebooks/.ignore.ci +++ b/notebooks/.ignore.ci @@ -1,5 +1,2 @@ -adding_custom_operator.ipynb adding_custom_operator_backward.ipynb -adding_operator_executor.ipynb dev_tutorials/extend.ipynb -dev_tutorials/patterns.ipynb diff --git a/notebooks/adding_custom_operator.ipynb b/notebooks/adding_custom_operator.ipynb index c293bc3167..7515a36c15 100644 --- a/notebooks/adding_custom_operator.ipynb +++ b/notebooks/adding_custom_operator.ipynb @@ -26,6 +26,14 @@ "from enum import Enum" ] }, + { + "cell_type": "markdown", + "id": "a1b6863a", + "metadata": {}, + "source": [ + "Let us define some helper functions (execute the cell below) for printing what's going on." + ] + }, { "cell_type": "code", "execution_count": 2, @@ -33,7 +41,6 @@ "metadata": {}, "outputs": [], "source": [ - "#@title Helper functions (execute this cell)\n", "import functools\n", "\n", "_indentation = 0\n", @@ -83,135 +90,92 @@ }, { "cell_type": "markdown", - "id": "a06c6260", + "id": "c8e1626f", "metadata": {}, "source": [ "Our new operator has the following signature `sincos(x: Tensor) -> Tuple[Tensor, Tensor]`. It takes a tensor as input and returns a tuple of two tensors. The first tensor is the sine of the input and the second tensor is the cosine of the input.\n", "\n", - "We call all callables that should be recorded in the trace Symbols. Symbols are the building blocks of the trace. Symbols are either primitives or composite operators. Composite perators are implemented in terms of other operators and primitives. Primitives are operators that are not implemented in terms of other operators or primitives.\n", + "We call all callables that should be recorded in the trace *Symbols*. Symbols are the building blocks of the trace. Symbols are either primitives or composite operators. Composite perators are implemented in terms of other operators and primitives. Primitives are operators that are not implemented in terms of other operators or primitives.\n", + "\n", + "The easiest way to register a new operator is through defining a meta - defining how the metadata of the output looks like give the metadata of the inputs and an implementation (dealing with concrete objects like Python `Number`s and PyTorch `Tensor`s) and register both of them through an executor. This will automatically create a symbol for us.\n", "\n", - "Let's create a new Symbol called `sincos` and implement it in Python." + "So we create an executor:" ] }, { "cell_type": "code", "execution_count": 3, - "id": "764c203a", + "id": "f680ae37", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Help on class Symbol in module thunder.core.symbol:\n", - "\n", - "class Symbol(builtins.object)\n", - " | Symbol(name: 'str', meta: 'Callable | None' = None, python_impl: 'Callable | None' = None, id: 'Any | None' = None, is_prim: 'bool' = False, is_fusion: 'bool' = False, python_printer: 'Callable' = , _module: 'Any | None' = None, _hash: 'Optional[int]' = None, _bind_postprocess: 'None | Callable' = None, _phantom: 'bool' = False) -> None\n", - " | \n", - " | Symbol(name: 'str', meta: 'Callable | None' = None, python_impl: 'Callable | None' = None, id: 'Any | None' = None, is_prim: 'bool' = False, is_fusion: 'bool' = False, python_printer: 'Callable' = , _module: 'Any | None' = None, _hash: 'Optional[int]' = None, _bind_postprocess: 'None | Callable' = None, _phantom: 'bool' = False)\n", - " | \n", - " | Methods defined here:\n", - " | \n", - " | __call__(self, *args, **kwargs)\n", - " | Call self as a function.\n", - " | \n", - " | __delattr__(self, name)\n", - " | Implement delattr(self, name).\n", - " | \n", - " | __eq__(self, other: 'Symbol') -> 'int'\n", - " | Return self==value.\n", - " | \n", - " | __getstate__ = _dataclass_getstate(self)\n", - " | # _dataclass_getstate and _dataclass_setstate are needed for pickling frozen\n", - " | # classes with slots. These could be slightly more performant if we generated\n", - " | # the code instead of iterating over fields. But that can be a project for\n", - " | # another day, if performance becomes an issue.\n", - " | \n", - " | __hash__(self) -> 'int'\n", - " | Return hash(self).\n", - " | \n", - " | __init__(self, name: 'str', meta: 'Callable | None' = None, python_impl: 'Callable | None' = None, id: 'Any | None' = None, is_prim: 'bool' = False, is_fusion: 'bool' = False, python_printer: 'Callable' = , _module: 'Any | None' = None, _hash: 'Optional[int]' = None, _bind_postprocess: 'None | Callable' = None, _phantom: 'bool' = False) -> None\n", - " | Initialize self. See help(type(self)) for accurate signature.\n", - " | \n", - " | __repr__(self) -> 'str'\n", - " | Return repr(self).\n", - " | \n", - " | __setattr__(self, name, value)\n", - " | Implement setattr(self, name, value).\n", - " | \n", - " | __setstate__ = _dataclass_setstate(self, state)\n", - " | \n", - " | bind(self, *args, output, subsymbols=(), _call_ctx=None, **kwargs) -> 'BoundSymbol'\n", - " | \n", - " | name_with_module(self)\n", - " | \n", - " | normalize(self, *args, **kwargs)\n", - " | \n", - " | ----------------------------------------------------------------------\n", - " | Readonly properties defined here:\n", - " | \n", - " | module\n", - " | \n", - " | ----------------------------------------------------------------------\n", - " | Data descriptors defined here:\n", - " | \n", - " | __weakref__\n", - " | list of weak references to the object (if defined)\n", - " | \n", - " | id\n", - " | \n", - " | is_fusion\n", - " | \n", - " | is_prim\n", - " | \n", - " | meta\n", - " | \n", - " | name\n", - " | \n", - " | python_impl\n", - " | \n", - " | python_printer\n", - " | \n", - " | ----------------------------------------------------------------------\n", - " | Data and other attributes defined here:\n", - " | \n", - " | __annotations__ = {'_bind_postprocess': 'None | Callable', '_hash': 'O...\n", - " | \n", - " | __dataclass_fields__ = {'_bind_postprocess': Field(name='_bind_postpro...\n", - " | \n", - " | __dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,or...\n", - " | \n", - " | __match_args__ = ('name', 'meta', 'python_impl', 'id', 'is_prim', 'is_...\n", - "\n" - ] + "data": { + "text/plain": [ + "[sincos_executor, sdpa]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "from thunder.core.symbol import Symbol\n", - "\n", - "help(Symbol)" + "sincos_executor = thunder.extend.OperatorExecutor(\"sincos_executor\", version='0.1')\n", + "thunder.add_default_executor(sincos_executor)" + ] + }, + { + "cell_type": "markdown", + "id": "4f147274", + "metadata": {}, + "source": [ + "We define meta and implementation: " ] }, { "cell_type": "code", "execution_count": 4, - "id": "ba10b306", + "id": "d5a72aff", "metadata": {}, "outputs": [], "source": [ "@log\n", - "def sincos_meta(input):\n", - " return (TensorProxy(like=input), TensorProxy(like=input))\n", + "def sincos_meta(inp):\n", + " return (TensorProxy(like=inp), TensorProxy(like=inp))\n", "\n", - "class CustomOps(Enum):\n", - " sincos = 0\n", - "\n", - "sincos = Symbol(\n", - " id=CustomOps.sincos,\n", - " name=\"sincos\",\n", - " meta=sincos_meta,\n", - " is_prim=True,\n", - ")" + "@log\n", + "def sincos_impl(inp):\n", + " return torch.sin(inp), torch.cos(inp)" + ] + }, + { + "cell_type": "markdown", + "id": "a06c6260", + "metadata": {}, + "source": [ + "And register it as `sincos`:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "03516b03", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Symbol name=sincos]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sincos = sincos_executor.register_operator('sincos', meta=sincos_meta, fn=sincos_impl)\n", + "sincos" ] }, { @@ -224,7 +188,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "id": "8c5da6f2", "metadata": {}, "outputs": [], @@ -236,13 +200,13 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "id": "aef98360", "metadata": {}, "outputs": [], "source": [ - "a = torch.randn(1, device=\"cuda\")\n", - "b = torch.randn(1, device=\"cuda\")" + "a = torch.randn(1)\n", + "b = torch.randn(1)" ] }, { @@ -255,7 +219,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "id": "87f9f6e7", "metadata": {}, "outputs": [ @@ -263,7 +227,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Couldn't find an eager implementation for sincos\n" + "Attempting to execute outside of a tracing context, which is not supported\n" ] } ], @@ -284,7 +248,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "id": "f938dff7-bac6-4807-b79d-a16cb5c6d90c", "metadata": {}, "outputs": [ @@ -292,23 +256,25 @@ "name": "stdout", "output_type": "stream", "text": [ - "call sincos_meta(TensorProxy(name=a, shape=(1,), dtype=float32, device=cuda:0))\n", - "|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cuda:0))\n", + "call sincos_meta(TensorProxy(name=a, shape=(1,), dtype=float32, device=cpu))\n", + "|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cpu), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cpu))\n", "\n", - "# import __main__ as __main__\n", - "# import thunder as thunder\n", - "# import thunder.torch as ltorch\n", + "# Constructed by Dead Code Elimination (took 0 milliseconds)\n", + "import thunder\n", + "import thunder.torch as ltorch\n", "import torch\n", + "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", + "@no_autocast()\n", "def fun(a, b):\n", - " # a: \"cuda:0 f32[1]\" \n", - " # b: \"cuda:0 f32[1]\" \n", - " (t0, t1) = __main__.sincos(a)\n", - " t2 = ltorch.add(t0, t1, alpha=None) # t2: \"cuda:0 f32[1]\"\n", - " # t2 = prims.add(t0, t1) # t2: \"cuda:0 f32[1]\"\n", - " t3 = ltorch.add(t2, b, alpha=None) # t3: \"cuda:0 f32[1]\"\n", - " # t3 = prims.add(t2, b) # t3: \"cuda:0 f32[1]\"\n", + " # a: \"cpu f32[1]\" \n", + " # b: \"cpu f32[1]\" \n", + " (t0, t1) = sincos(a)\n", + " t2 = ltorch.add(t0, t1, alpha=None) # t2: \"cpu f32[1]\"\n", + " # t2 = prims.add(t0, t1) # t2: \"cpu f32[1]\"\n", + " t3 = ltorch.add(t2, b, alpha=None) # t3: \"cpu f32[1]\"\n", + " # t3 = prims.add(t2, b) # t3: \"cpu f32[1]\"\n", " return t3\n" ] } @@ -321,7 +287,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "id": "6eb4818b", "metadata": {}, "outputs": [ @@ -329,17 +295,17 @@ "name": "stdout", "output_type": "stream", "text": [ - "Bound symbol with id=PrimIDs.UNPACK_TRIVIAL is represented in the trace as |# a: \"cuda:0 f32[1]\" |\n", - "Bound symbol with id=PrimIDs.UNPACK_TRIVIAL is represented in the trace as |# b: \"cuda:0 f32[1]\" |\n", - "Bound symbol with id=CustomOps.sincos is represented in the trace as |(t0, t1) = __main__.sincos(a)|\n", - "Bound symbol with id=torch.add is represented in the trace as |t2 = ltorch.add(t0, t1, alpha=None) # t2: \"cuda:0 f32[1]\"\n", - " # t2 = prims.add(t0, t1) # t2: \"cuda:0 f32[1]\"|\n", + "Bound symbol with id=PrimIDs.UNPACK_TRIVIAL is represented in the trace as |# a: \"cpu f32[1]\" |\n", + "Bound symbol with id=PrimIDs.UNPACK_TRIVIAL is represented in the trace as |# b: \"cpu f32[1]\" |\n", + "Bound symbol with id=sincos is represented in the trace as |(t0, t1) = sincos(a)|\n", + "Bound symbol with id=torch.add is represented in the trace as |t2 = ltorch.add(t0, t1, alpha=None) # t2: \"cpu f32[1]\"\n", + " # t2 = prims.add(t0, t1) # t2: \"cpu f32[1]\"|\n", " It has the following subsymbols:\n", - " id=PrimIDs.ADD |t2 = prims.add(t0, t1) # t2: \"cuda:0 f32[1]\"|\n", - "Bound symbol with id=torch.add is represented in the trace as |t3 = ltorch.add(t2, b, alpha=None) # t3: \"cuda:0 f32[1]\"\n", - " # t3 = prims.add(t2, b) # t3: \"cuda:0 f32[1]\"|\n", + " id=PrimIDs.ADD |t2 = prims.add(t0, t1) # t2: \"cpu f32[1]\"|\n", + "Bound symbol with id=torch.add is represented in the trace as |t3 = ltorch.add(t2, b, alpha=None) # t3: \"cpu f32[1]\"\n", + " # t3 = prims.add(t2, b) # t3: \"cpu f32[1]\"|\n", " It has the following subsymbols:\n", - " id=PrimIDs.ADD |t3 = prims.add(t2, b) # t3: \"cuda:0 f32[1]\"|\n", + " id=PrimIDs.ADD |t3 = prims.add(t2, b) # t3: \"cpu f32[1]\"|\n", "Bound symbol with id=PrimIDs.RETURN is represented in the trace as |return t3|\n" ] } @@ -364,7 +330,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "id": "41566de2-a60f-4c87-a3d6-58e6a89dc38b", "metadata": {}, "outputs": [], @@ -374,151 +340,292 @@ }, { "cell_type": "code", - "execution_count": 11, - "id": "bbbb90c2", + "execution_count": 12, + "id": "24af4b99", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "call sincos_meta(TensorProxy(name=a, shape=(1,), dtype=float32, device=cuda:0))\n", - "|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cuda:0))\n", + "call sincos_meta(TensorProxy(name=t_0, shape=(1,), dtype=float32, device=cpu))\n", + "|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cpu), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cpu))\n", "\n", - "Could not find executor for bound symbol (t0, t1) = __main__.sincos(a)\n" + "call sincos_impl(Tensor(shape=torch.Size([1]), stride=(1,), dtype=torch.float32, device=cpu) with values tensor([0.1413]))\n", + "|<- sincos_impl = (Tensor(shape=torch.Size([1]), stride=(1,), dtype=torch.float32, device=cpu) with values tensor([0.1408]), Tensor(shape=torch.Size([1]), stride=(1,), dtype=torch.float32, device=cpu) with values tensor([0.9900]))\n", + "\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of unknown type NoneType, which may or may not be safe. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!\n", + " warnings.warn(s)\n", + "/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of type bool, which is not identified as an input. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!\n", + " warnings.warn(s)\n", + "/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of unknown type SequenceIter, which may or may not be safe. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!\n", + " warnings.warn(s)\n", + "/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of type int, which is not identified as an input. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!\n", + " warnings.warn(s)\n", + "/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of unknown type NotImplementedType, which may or may not be safe. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!\n", + " warnings.warn(s)\n", + "/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of unknown type StopIteration, which may or may not be safe. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!\n", + " warnings.warn(s)\n" + ] + }, + { + "data": { + "text/plain": [ + "tensor([0.7666])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "try:\n", - " cfun(a, b)\n", - "except RuntimeError as e:\n", - " print(e)" + "cfun(a, b)" ] }, { "cell_type": "markdown", - "id": "3b1fd6e3", + "id": "d7cec09d", "metadata": {}, "source": [ - "There's no registered executor for `sincos` so we need to register an executor for our new primitive. Let's do that." + "Let's check how our function is represented in the execution trace now (change to `thunder.last_traces(cfun)[0]` to see the trace before transformations)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "a7ff30ef", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# Constructed by Delete Last Used (took 0 milliseconds)\n", + "import torch\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def computation(a, b):\n", + " # a: \"cpu f32[1]\" \n", + " # b: \"cpu f32[1]\" \n", + " (res, cos) = sincos(a)\n", + " del a\n", + " result = torch.add(res, cos) # result: \"cpu f32[1]\"\n", + " # result = ltorch.add(res, cos, alpha=None) # result: \"cpu f32[1]\"\n", + " # result = prims.add(res, cos) # result: \"cpu f32[1]\"\n", + " del res, cos\n", + " t3 = torch.add(result, b) # t3: \"cpu f32[1]\"\n", + " # t3 = ltorch.add(result, b, alpha=None) # t3: \"cpu f32[1]\"\n", + " # t3 = prims.add(result, b) # t3: \"cpu f32[1]\"\n", + " del result, b\n", + " return t3" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "thunder.last_traces(cfun)[-1]" ] }, { "cell_type": "markdown", - "id": "026680b3-7b46-4f4b-b16b-641fa9bdcdf4", + "id": "35b71375", "metadata": {}, "source": [ - "Check out the \"adding-operator-executor.ipynb\" notebook to see how to implement an executor for a Symbol." + "For a peek under the hood, we can also first create a new symbol (without reference to an executor) and then register an executor for that.\n" ] }, { "cell_type": "code", - "execution_count": 12, - "id": "2460f808-eacb-4a0f-8f62-6a17e3dce6e8", + "execution_count": 14, + "id": "f28094bb", "metadata": {}, "outputs": [], "source": [ - "from thunder.executors import add_operator_executor\n", - "\n", - "@log\n", - "def checker_sincos(a):\n", - " # We allow the sincos function to be called with any tensor\n", - " return True\n", - "\n", + "from thunder.core.symbol import Symbol\n", "@log\n", - "def executor_sincos(a):\n", - " return torch.sin(a), torch.cos(a)\n", + "def sincos_meta(input):\n", + " return (TensorProxy(like=input), TensorProxy(like=input))\n", "\n", - "op_map = {\n", - " CustomOps.sincos: (\"sincos\", checker_sincos, executor_sincos)\n", - "}\n", + "# this gives a nice, unique, printable id\n", + "class CustomOps(Enum):\n", + " sincos2 = 0\n", "\n", - "add_operator_executor(\"sincos_executor\", op_map, add_to_default_executors=True)" + "sincos2 = Symbol(\n", + " id=CustomOps.sincos2,\n", + " name=\"sincos2\",\n", + " meta=sincos_meta,\n", + " is_prim=True,\n", + ")" ] }, { "cell_type": "code", - "execution_count": 13, - "id": "d864fa05", + "execution_count": 15, + "id": "7fbab758", "metadata": {}, "outputs": [], "source": [ - "# Let's try again\n", - "cfun = thunder.compile(fun, disable_preprocessing=True)" + "def fun2(a, b):\n", + " sin, cos = sincos2(a)\n", + " return sin + cos + b\n", + "\n", + "cfun2 = thunder.jit(fun2)" ] }, { "cell_type": "code", - "execution_count": 14, - "id": "24af4b99", + "execution_count": 16, + "id": "950d74ad", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "call sincos_meta(TensorProxy(name=a, shape=(1,), dtype=float32, device=cuda:0))\n", - "|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cuda:0))\n", + "call sincos_meta(TensorProxy(name=t_0, shape=(1,), dtype=float32, device=cpu))\n", + "|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cpu), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cpu))\n", "\n", - "call checker_sincos(TensorProxy(name=a, shape=(1,), dtype=float32, device=cuda:0))\n", - "|<- checker_sincos = True\n", + "Failed to find an executor for bound symbol bsym=(res, cos) = __main__.sincos2(a)\n" + ] + } + ], + "source": [ + "try:\n", + " cfun2(a, b)\n", + "except RuntimeError as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "id": "aadcf2a9", + "metadata": {}, + "source": [ + "There's no registered executor for `sincos` so we need to register an executor for our new primitive. Let's do that." + ] + }, + { + "cell_type": "markdown", + "id": "995febba", + "metadata": {}, + "source": [ + "Check out the \"adding-operator-executor.ipynb\" notebook to see how to implement an executor for a Symbol." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "956a4a6b", + "metadata": {}, + "outputs": [], + "source": [ + "@log\n", + "def checker_sincos2(a):\n", + " # We allow the sincos function to be called with any tensor\n", + " return True\n", + "\n", + "@log\n", + "def executor_sincos2(a):\n", + " # we need to have something here works with TensorProxies during the transformations,\n", + " # so we need to functions from thunder.torch or thunder.clang or other Symbols \n", + " return thunder.torch.sin(a), thunder.torch.cos(a)\n", + "\n", + "sincos_executor.register_implementation(sincos2, checker=checker_sincos2, execution_transform=executor_sincos2)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "1c77c508", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "call sincos_meta(TensorProxy(name=t_0, shape=(1,), dtype=float32, device=cpu))\n", + "|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cpu), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cpu))\n", "\n", - "call checker_sincos(TensorProxy(name=a, shape=(1,), dtype=float32, device=cuda:0))\n", - "|<- checker_sincos = True\n", + "call checker_sincos2(TensorProxy(name=a, shape=(1,), dtype=float32, device=cpu))\n", + "|<- checker_sincos2 = True\n", "\n", - "call executor_sincos(Tensor(shape=torch.Size([1]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([-0.6296], device='cuda:0'))\n", - "|<- executor_sincos = (Tensor(shape=torch.Size([1]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([-0.5889], device='cuda:0'), Tensor(shape=torch.Size([1]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([0.8082], device='cuda:0'))\n", + "call executor_sincos2(TensorProxy(name=a, shape=(1,), dtype=float32, device=cpu))\n", + "|<- executor_sincos2 = (TensorProxy(name=t4, shape=(1,), dtype=float32, device=cpu), TensorProxy(name=t5, shape=(1,), dtype=float32, device=cpu))\n", "\n" ] }, { "data": { "text/plain": [ - "tensor([0.1889], device='cuda:0')" + "tensor([0.7666])" ] }, - "execution_count": 14, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "cfun(a, b)" + "# Let's try again\n", + "cfun2 = thunder.jit(fun2)\n", + "cfun2(a, b)" ] }, { "cell_type": "code", - "execution_count": 15, - "id": "a7ff30ef", + "execution_count": 19, + "id": "f9797cf2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "# Constructed by Delete Last Used\n", + "# Constructed by Delete Last Used (took 0 milliseconds)\n", "import torch\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", "@torch.no_grad()\n", - "def fun(a, b):\n", - " # a: \"cuda:0 f32[1]\" \n", - " # b: \"cuda:0 f32[1]\" \n", - " (t0, t1) = sincos(a)\n", - " del [a]\n", - " (t3,) = nvFusion0(b, t0, t1)\n", - " # t2 = prims.add(t0, t1) # t2: \"cuda:0 f32[1]\"\n", - " # t3 = prims.add(t2, b) # t3: \"cuda:0 f32[1]\"\n", - " del [b, t0, t1]\n", + "@no_autocast()\n", + "def computation(a, b):\n", + " # a: \"cpu f32[1]\" \n", + " # b: \"cpu f32[1]\" \n", + " res = torch.sin(a) # res: \"cpu f32[1]\"\n", + " # res = ltorch.sin(a) # res: \"cpu f32[1]\"\n", + " # res = prims.sin(a) # res: \"cpu f32[1]\"\n", + " cos = torch.cos(a) # cos: \"cpu f32[1]\"\n", + " # cos = ltorch.cos(a) # cos: \"cpu f32[1]\"\n", + " # cos = prims.cos(a) # cos: \"cpu f32[1]\"\n", + " del a\n", + " result = torch.add(res, cos) # result: \"cpu f32[1]\"\n", + " # result = ltorch.add(res, cos, alpha=None) # result: \"cpu f32[1]\"\n", + " # result = prims.add(res, cos) # result: \"cpu f32[1]\"\n", + " del res, cos\n", + " t3 = torch.add(result, b) # t3: \"cpu f32[1]\"\n", + " # t3 = ltorch.add(result, b, alpha=None) # t3: \"cpu f32[1]\"\n", + " # t3 = prims.add(result, b) # t3: \"cpu f32[1]\"\n", + " del result, b\n", " return t3" ] }, - "execution_count": 15, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Let's check how our function is represented in the execution trace now\n", - "thunder.last_traces(cfun)[-1]" + "thunder.last_traces(cfun2)[-1]" ] }, { @@ -550,7 +657,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.11.8" } }, "nbformat": 4, diff --git a/notebooks/adding_custom_operator_backward.ipynb b/notebooks/adding_custom_operator_backward.ipynb index 2ed02db162..dc44b8c3d7 100644 --- a/notebooks/adding_custom_operator_backward.ipynb +++ b/notebooks/adding_custom_operator_backward.ipynb @@ -370,11 +370,11 @@ "\n", "@torch.no_grad()\n", "@no_autocast()\n", - "def computation(a, target):\n", - " # a: \"cuda:0 f32[2048, 50257]\" \n", - " # target: \"cuda:0 i64[2048]\" \n", - " (res, _) = apex_xentropy_forward(a, target, None, None, -100, None, 'none', 0.0)\n", - " del a, target\n", + "def computation(logits, labels):\n", + " # logits: \"cuda:0 f32[2048, 50257]\" \n", + " # labels: \"cuda:0 i64[2048]\" \n", + " (res, _) = apex_xentropy_forward(logits, labels, None, None, -100, None, 'none', 0.0)\n", + " del logits, labels\n", " return res" ] }, @@ -626,188 +626,31 @@ }, { "cell_type": "markdown", - "id": "39fd6fce", + "id": "b4ec7c57", "metadata": {}, "source": [ - "With this, we can use the `grad` transform to get the gradient:" + "With these registrations, we can compile a function and it will be automatically transformed into forward and backward and wrapped in a PyTorch autograd.Function calling the backward trace computed by Thunder.\n" ] }, { "cell_type": "code", "execution_count": 12, - "id": "d9f6dfde", + "id": "8c5da6f2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "call apex_cross_entropy_grad(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -100, None, none, 0.0)\n", - " call apex_xentropy_forward_meta(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -100, None, none, 0.0)\n", + "call apex_cross_entropy_grad(TensorProxy(name=a, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=target, shape=(2048,), dtype=int64, device=cuda:0), None, None, [IntegerProxy name=ignore_index, value=-1], None, none, [FloatProxy name=label_smoothing, value=0.0])\n", + " call apex_xentropy_forward_meta(TensorProxy(name=a, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=target, shape=(2048,), dtype=int64, device=cuda:0), None, None, [IntegerProxy name=ignore_index, value=-1], None, none, [FloatProxy name=label_smoothing, value=0.0])\n", " |<- apex_xentropy_forward_meta = (TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0))\n", "\n", - " call apex_xentropy_backward_meta(TensorProxy(name=t2, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0), 0.0)\n", + " call apex_xentropy_backward_meta(TensorProxy(name=t2, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=a, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=target, shape=(2048,), dtype=int64, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0), [FloatProxy name=label_smoothing, value=0.0])\n", " |<- apex_xentropy_backward_meta = TensorProxy(name=t3, shape=(2048, 50257), dtype=float32, device=cuda:0)\n", "\n", "|<- apex_cross_entropy_grad = TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0)\n", "\n", - "call apex_xentropy_forward_impl(Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 0.7825, -1.1014, -0.9563, ..., 0.2801, 0.5359, -1.4094],\n", - " [ 1.1592, 0.8128, 0.5846, ..., 1.0255, 0.4217, 0.2548],\n", - " [ 0.8622, 0.5320, -1.5205, ..., -1.4938, -1.0423, -0.9527],\n", - " ...,\n", - " [-0.8978, 2.1914, 0.1603, ..., 0.0704, -0.7642, 1.4002],\n", - " [ 0.1750, 0.6244, 1.1711, ..., 0.3491, -0.5760, -1.4034],\n", - " [ 1.3689, -1.5422, 0.8149, ..., 0.9625, 1.0281, 1.4206]],\n", - " device='cuda:0', requires_grad=True), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([43812, 33387, 31729, ..., 27740, 2907, 8268], device='cuda:0'), None, None, -100, None, none, 0.0)\n", - "|<- apex_xentropy_forward_impl = (Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.8060, 10.7141, 11.4505, ..., 11.2361, 10.6558, 11.2219],\n", - " device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3327, 11.3294, 11.3304, ..., 11.3231, 11.3170, 11.3209],\n", - " device='cuda:0'))\n", - "\n", - "call apex_xentropy_backward_impl(Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([1., 1., 1., ..., 1., 1., 1.], device='cuda:0'), Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 0.7825, -1.1014, -0.9563, ..., 0.2801, 0.5359, -1.4094],\n", - " [ 1.1592, 0.8128, 0.5846, ..., 1.0255, 0.4217, 0.2548],\n", - " [ 0.8622, 0.5320, -1.5205, ..., -1.4938, -1.0423, -0.9527],\n", - " ...,\n", - " [-0.8978, 2.1914, 0.1603, ..., 0.0704, -0.7642, 1.4002],\n", - " [ 0.1750, 0.6244, 1.1711, ..., 0.3491, -0.5760, -1.4034],\n", - " [ 1.3689, -1.5422, 0.8149, ..., 0.9625, 1.0281, 1.4206]],\n", - " device='cuda:0', requires_grad=True), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([43812, 33387, 31729, ..., 27740, 2907, 8268], device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3327, 11.3294, 11.3304, ..., 11.3231, 11.3170, 11.3209],\n", - " device='cuda:0'), 0.0)\n", - "|<- apex_xentropy_backward_impl = Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[2.6187e-05, 3.9804e-06, 4.6022e-06, ..., 1.5845e-05, 2.0466e-05,\n", - " 2.9255e-06],\n", - " [3.8294e-05, 2.7081e-05, 2.1557e-05, ..., 3.3501e-05, 1.8315e-05,\n", - " 1.5500e-05],\n", - " [2.8425e-05, 2.0432e-05, 2.6236e-06, ..., 2.6946e-06, 4.2325e-06,\n", - " 4.6290e-06],\n", - " ...,\n", - " [4.9265e-06, 1.0818e-04, 1.4192e-05, ..., 1.2971e-05, 5.6303e-06,\n", - " 4.9039e-05],\n", - " [1.4491e-05, 2.2712e-05, 3.9235e-05, ..., 1.7247e-05, 6.8383e-06,\n", - " 2.9895e-06],\n", - " [4.7630e-05, 2.5919e-06, 2.7372e-05, ..., 3.1723e-05, 3.3876e-05,\n", - " 5.0159e-05]], device='cuda:0')\n", - "\n", - "Difference: 1.3969838619232178e-09\n", - "# Constructed by Delete Last Used (took 0 milliseconds)\n", - "import torch\n", - "from thunder.executors.torchex import no_autocast\n", - "\n", - "@torch.no_grad()\n", - "@no_autocast()\n", - "def computation(logits, labels):\n", - " # logits: \"cuda:0 f32[2048, 50257]\" \n", - " # labels: \"cuda:0 i64[2048]\" \n", - " (_, t0) = apex_xentropy_forward(logits, labels, None, None, -100, None, 'none', 0.0)\n", - " t4 = torch.full((2048,), 1.0, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t4: \"cuda:0 f32[2048]\"\n", - " # t4 = ltorch.full((2048,), 1.0, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t4: \"cuda:0 f32[2048]\"\n", - " # t4 = prims.full((2048,), 1.0, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t4: \"cuda:0 f32[2048]\"\n", - " t3 = apex_xentropy_backward(t4, logits, labels, t0, 0.0) # t3: \"cuda:0 f32[2048, 50257]\"\n", - " del t4, logits, labels, t0\n", - " return [t3]\n" - ] - } - ], - "source": [ - "logits = torch.randn([2048, 50257], device=\"cuda\", requires_grad=True)\n", - "labels = torch.randint(0, 50257, [2048], device=\"cuda\")\n", - "\n", - "grad_jfn = thunder.core.transforms.grad(jfn)\n", - "actual_grad, = grad_jfn(logits, labels)\n", - "\n", - "expected_grad, = torch.autograd.grad(loss_fn(logits, labels).sum(), logits)\n", - "\n", - "\n", - "print(\"Difference:\", (actual_grad - expected_grad).abs().max().item())\n", - "print(thunder.last_traces(grad_jfn)[-1])\n", - " \n", - "\n" - ] - }, - { - "cell_type": "markdown", - "id": "f9f85e3b", - "metadata": {}, - "source": [ - "But life isn't completely simple. When we noticed that we thought about how to do backward for a long time, this is our previous approach, that is (in March 2024) needed for getting PyTorch Autograd integration.\n", - "This works by having a _forward rule_ for generating a tuple of result and values saved for backward and a _backward rule_ that takes the saved values and output grad to compute the input grads, much like PyTorch autograd itself, but with the pluggable executor architecture of Thunder.\n", - "\n", - "We are working at allowing you to skip this part!" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "b0379bc4", - "metadata": {}, - "outputs": [], - "source": [ - "from thunder.core.transforms import register_augmented_forward_with_checker, register_backward\n", - "\n", - "def apex_xentropy_forward_rule(\n", - " a,\n", - " target,\n", - " weight=None,\n", - " size_average=None,\n", - " ignore_index=-100,\n", - " reduce=None,\n", - " reduction=\"mean\",\n", - " label_smoothing=0.0,\n", - "):\n", - " loss, max_log_sum_exp = apex_xentropy_forward(\n", - " a,\n", - " target,\n", - " weight,\n", - " size_average,\n", - " ignore_index,\n", - " reduce,\n", - " reduction,\n", - " label_smoothing,\n", - " )\n", - " primal = loss\n", - " saved_for_backward = (a, target, max_log_sum_exp, reduction, label_smoothing)\n", - " return primal, saved_for_backward\n", - "\n", - "register_augmented_forward_with_checker(\n", - " apex_xentropy_ex,\n", - " \"torch.nn.functional.cross_entropy\",\n", - " apex_xentropy_checker,\n", - " apex_xentropy_forward_rule,\n", - ")\n", - "\n", - "@register_backward((apex_xentropy_ex, thunder.torch.cross_entropy.id))\n", - "def apex_cross_entropy_backward_rule(\n", - " logits, labels, max_log_sum_exp, reduction, smoothing, grad\n", - "):\n", - " if reduction != \"none\":\n", - " raise ValueError(f\"Invalid reduction: {reduction}\")\n", - "\n", - " grad_logits = apex_xentropy_backward(\n", - " grad,\n", - " logits,\n", - " labels,\n", - " max_log_sum_exp,\n", - " smoothing,\n", - " )\n", - " return grad_logits, *([None] * 7)" - ] - }, - { - "cell_type": "markdown", - "id": "b4ec7c57", - "metadata": {}, - "source": [ - "With these registrations, we can compile a function and it will be automatically transformed into forward and backward and wrapped in a PyTorch autograd.Function calling the backward trace computed by Thunder.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "8c5da6f2", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ "call apex_xentropy_forward_meta(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -1, None, none, 0.0)\n", "|<- apex_xentropy_forward_meta = (TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0))\n", "\n", @@ -861,7 +704,7 @@ " -4.9872e-05, -6.3328e-05]], device='cuda:0')\n", "\n", "Max error in loss: 9.5367431640625e-07\n", - "Max error in logits grad: 1.3969838619232178e-09\n" + "Max error in logits grad: 2.384185791015625e-07\n" ] }, { @@ -966,13 +809,12 @@ " return (t3, None)]" ] }, - "execution_count": 14, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "from thunder.core.transforms import value_and_grad\n", "from thunder import torch as ltorch\n", "\n", "torch.manual_seed(0)\n", @@ -988,10 +830,10 @@ "actual_loss = cfn(logits, labels)\n", "go = torch.randn_like(actual_loss)\n", "\n", - "actual_grads, = torch.autograd.grad(actual_loss, logits, go)\n", + "actual_grad, = torch.autograd.grad(actual_loss, logits, go)\n", "\n", "expected_loss = loss_fn(logits, labels)\n", - "expected_grads, = torch.autograd.grad(expected_loss, logits, go)\n", + "expected_grad, = torch.autograd.grad(expected_loss, logits, go)\n", "\n", "print(\"Max error in loss:\", (actual_loss - expected_loss).abs().max().item())\n", "print(\"Max error in logits grad:\", (actual_grad - expected_grad).abs().max().item())\n", @@ -999,6 +841,102 @@ "thunder.last_traces(cfn)[-1]" ] }, + { + "cell_type": "markdown", + "id": "54d6a5ea", + "metadata": {}, + "source": [ + "Alternatively, we can also use the `grad` transform to get the gradient:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "c88118eb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "call apex_cross_entropy_grad(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -100, None, none, 0.0)\n", + " call apex_xentropy_forward_meta(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -100, None, none, 0.0)\n", + " |<- apex_xentropy_forward_meta = (TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0))\n", + "\n", + " call apex_xentropy_backward_meta(TensorProxy(name=t2, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0), 0.0)\n", + " |<- apex_xentropy_backward_meta = TensorProxy(name=t3, shape=(2048, 50257), dtype=float32, device=cuda:0)\n", + "\n", + "|<- apex_cross_entropy_grad = TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0)\n", + "\n", + "call apex_xentropy_forward_impl(Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 0.5390, 0.1760, -1.0790, ..., 0.1695, -0.8082, -0.6984],\n", + " [ 2.1555, 1.3938, 0.3928, ..., 0.8937, -0.4949, 1.1610],\n", + " [ 0.6784, 1.1188, 0.7508, ..., -0.0941, 0.8380, 0.1878],\n", + " ...,\n", + " [-1.5834, -0.1573, -1.3511, ..., 0.6167, -0.1083, 0.4116],\n", + " [-0.5476, 0.5831, 0.0791, ..., -0.4986, -0.5270, 0.0954],\n", + " [ 0.2825, -1.0378, -0.5506, ..., 0.0149, 1.3521, -1.0823]],\n", + " device='cuda:0', requires_grad=True), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([44917, 35770, 41569, ..., 9798, 33992, 36123], device='cuda:0'), None, None, -100, None, none, 0.0)\n", + "|<- apex_xentropy_forward_impl = (Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([10.0233, 11.9095, 11.2898, ..., 10.9289, 10.7487, 10.7455],\n", + " device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3241, 11.3207, 11.3283, ..., 11.3224, 11.3186, 11.3205],\n", + " device='cuda:0'))\n", + "\n", + "call apex_xentropy_backward_impl(Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([1., 1., 1., ..., 1., 1., 1.], device='cuda:0'), Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 0.5390, 0.1760, -1.0790, ..., 0.1695, -0.8082, -0.6984],\n", + " [ 2.1555, 1.3938, 0.3928, ..., 0.8937, -0.4949, 1.1610],\n", + " [ 0.6784, 1.1188, 0.7508, ..., -0.0941, 0.8380, 0.1878],\n", + " ...,\n", + " [-1.5834, -0.1573, -1.3511, ..., 0.6167, -0.1083, 0.4116],\n", + " [-0.5476, 0.5831, 0.0791, ..., -0.4986, -0.5270, 0.0954],\n", + " [ 0.2825, -1.0378, -0.5506, ..., 0.0149, 1.3521, -1.0823]],\n", + " device='cuda:0', requires_grad=True), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([44917, 35770, 41569, ..., 9798, 33992, 36123], device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3241, 11.3207, 11.3283, ..., 11.3224, 11.3186, 11.3205],\n", + " device='cuda:0'), 0.0)\n", + "|<- apex_xentropy_backward_impl = Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[2.0706e-05, 1.4403e-05, 4.1058e-06, ..., 1.4309e-05, 5.3827e-06,\n", + " 6.0079e-06],\n", + " [1.0461e-04, 4.8840e-05, 1.7949e-05, ..., 2.9621e-05, 7.3879e-06,\n", + " 3.8697e-05],\n", + " [2.3705e-05, 3.6822e-05, 2.5485e-05, ..., 1.0948e-05, 2.7806e-05,\n", + " 1.4513e-05],\n", + " ...,\n", + " [2.4836e-06, 1.0338e-05, 3.1331e-06, ..., 2.2417e-05, 1.0857e-05,\n", + " 1.8259e-05],\n", + " [7.0235e-06, 2.1758e-05, 1.3145e-05, ..., 7.3762e-06, 7.1699e-06,\n", + " 1.3360e-05],\n", + " [1.6078e-05, 4.2941e-06, 6.9897e-06, ..., 1.2304e-05, 4.6857e-05,\n", + " 4.1070e-06]], device='cuda:0')\n", + "\n", + "Difference: 1.3969838619232178e-09\n", + "# Constructed by Delete Last Used (took 0 milliseconds)\n", + "import torch\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def computation(logits, labels):\n", + " # logits: \"cuda:0 f32[2048, 50257]\" \n", + " # labels: \"cuda:0 i64[2048]\" \n", + " (_, t0) = apex_xentropy_forward(logits, labels, None, None, -100, None, 'none', 0.0)\n", + " t4 = torch.full((2048,), 1.0, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t4: \"cuda:0 f32[2048]\"\n", + " # t4 = ltorch.full((2048,), 1.0, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t4: \"cuda:0 f32[2048]\"\n", + " # t4 = prims.full((2048,), 1.0, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t4: \"cuda:0 f32[2048]\"\n", + " t3 = apex_xentropy_backward(t4, logits, labels, t0, 0.0) # t3: \"cuda:0 f32[2048, 50257]\"\n", + " del t4, logits, labels, t0\n", + " return [t3]\n" + ] + } + ], + "source": [ + "logits = torch.randn([2048, 50257], device=\"cuda\", requires_grad=True)\n", + "labels = torch.randint(0, 50257, [2048], device=\"cuda\")\n", + "\n", + "grad_jfn = thunder.core.transforms.grad(jfn)\n", + "actual_grad, = grad_jfn(logits, labels)\n", + "\n", + "expected_grad, = torch.autograd.grad(loss_fn(logits, labels).sum(), logits)\n", + "\n", + "\n", + "print(\"Difference:\", (actual_grad - expected_grad).abs().max().item())\n", + "print(thunder.last_traces(grad_jfn)[-1])\n" + ] + }, { "cell_type": "markdown", "id": "e234a47b", @@ -1008,8 +946,7 @@ "\n", "- We defined a custom executor with custom operations (Symbols in Thunder language), each with a *Meta-* (data propagation) *function* and an implementation.\n", "- We defined and registered rules to map existing operations to our new operations. This allows us to use optimizations on our model without changing the model's code! \n", - "- We defined a gradient rule and saw how we the `grad` transform uses it.\n", - "- We saw another (older) way to implement forward and backward rules that is currently needed to get automatic integration with PyTorch's autograd.\n", + "- We defined a gradient rule and saw how our automatic PyTorch Autograd integration or the explicit `grad` transform uses it.\n", "\n", "Now go and implement your favourite optimized operators. We would love to hear about your use-cases!\n" ] diff --git a/notebooks/adding_operator_executor.ipynb b/notebooks/adding_operator_executor.ipynb deleted file mode 100644 index eac5685d3e..0000000000 --- a/notebooks/adding_operator_executor.ipynb +++ /dev/null @@ -1,688 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "b6f1f42d-f146-4c9c-8ed8-74f2bcf153f0", - "metadata": {}, - "source": [ - "# Adding an operator executor\n", - "\n", - "We are going to write a simple executor for `prims.add` function that calls NumPy's addition function. Our executor will be restricted to only work with inputs with certain properties. We will use the `add_operator_executor` function to create our executor." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "576d267d-9cef-4414-a722-b2cef0665cce", - "metadata": {}, - "outputs": [], - "source": [ - "import thunder\n", - "import torch\n", - "import numpy as np" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "02e16bf5", - "metadata": {}, - "outputs": [], - "source": [ - "#@title Helper functions (execute this cell)\n", - "import functools\n", - "\n", - "_indentation = 0\n", - "def _log(msg=None):\n", - " \"\"\"Print a message at current indentation.\"\"\"\n", - " if msg is not None:\n", - " print(\" \" * _indentation + msg)\n", - "\n", - "def _log_indent(msg=None):\n", - " \"\"\"Print a message and then indent the rest.\"\"\"\n", - " global _indentation\n", - " _log(msg)\n", - " _indentation = 2 + _indentation\n", - "\n", - "def _log_unindent(msg=None):\n", - " \"\"\"Unindent then print a message.\"\"\"\n", - " global _indentation\n", - " _indentation = _indentation - 2\n", - " _log(msg)\n", - " \n", - "def log(func):\n", - " \"\"\"A decorator for functions to log arguments and results.\"\"\"\n", - " name = func.__name__\n", - " def pp(v):\n", - " \"\"\"Print certain values more succinctly\"\"\"\n", - " vtype = str(type(v))\n", - " if isinstance(v, tuple):\n", - " return \"({})\".format(pp_values(v))\n", - " elif isinstance(v, thunder.core.proxies.TensorProxy):\n", - " return f\"TensorProxy(name={v.name}, shape={v.shape}, dtype={v.dtype}, device={v.device})\"\n", - " elif isinstance(v, torch.Tensor):\n", - " return f\"Tensor(shape={v.shape}, stride={v.stride()}, dtype={v.dtype}, device={v.device}) with values {v}\"\n", - " else:\n", - " return str(v)\n", - " def pp_values(args):\n", - " return \", \".join([pp(arg) for arg in args])\n", - "\n", - " @functools.wraps(func)\n", - " def func_wrapper(*args):\n", - " _log_indent(\"call {}({})\".format(name, pp_values(args)))\n", - " res = func(*args)\n", - " _log_unindent(\"|<- {} = {}\\n\".format(name, pp(res)))\n", - " return res\n", - "\n", - " return func_wrapper" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "666fa494-21f5-4ed7-829e-f8648fddb13a", - "metadata": {}, - "outputs": [], - "source": [ - "# This is our test function\n", - "def fun(a, b):\n", - " return a + b * a" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "edbf395f-9549-4ae3-957c-ba34fc956b3f", - "metadata": {}, - "outputs": [], - "source": [ - "# This is our test input\n", - "a = torch.randn(2, 2, device=\"cuda\")\n", - "b = torch.randn(2, 1, device=\"cuda\")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "f938dff7-bac6-4807-b79d-a16cb5c6d90c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "# import thunder as thunder\n", - "# import thunder.torch as ltorch\n", - "import torch\n", - "\n", - "@torch.no_grad()\n", - "def fun(a, b):\n", - " # a: \"cuda:0 f32[2, 2]\" \n", - " # b: \"cuda:0 f32[2, 1]\" \n", - " t1 = ltorch.mul(b, a) # t1: \"cuda:0 f32[2, 2]\"\n", - " # t0 = prims.broadcast_in_dim(b, [2, 2], (0, 1)) # t0: \"cuda:0 f32[2, 2]\"\n", - " # t1 = prims.mul(t0, a) # t1: \"cuda:0 f32[2, 2]\"\n", - " t2 = ltorch.add(a, t1, alpha=None) # t2: \"cuda:0 f32[2, 2]\"\n", - " # t2 = prims.add(a, t1) # t2: \"cuda:0 f32[2, 2]\"\n", - " return t2\n" - ] - } - ], - "source": [ - "# Let's see first how this function is represented as a trace\n", - "trace = thunder.trace()(fun, a, b)\n", - "print(trace)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "6eb4818b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Bound symbol with id=PrimIDs.UNPACK_TRIVIAL is represented in the trace as |# a: \"cuda:0 f32[2, 2]\" |\n", - "Bound symbol with id=PrimIDs.UNPACK_TRIVIAL is represented in the trace as |# b: \"cuda:0 f32[2, 1]\" |\n", - "Bound symbol with id=torch.mul is represented in the trace as |t1 = ltorch.mul(b, a) # t1: \"cuda:0 f32[2, 2]\"\n", - " # t0 = prims.broadcast_in_dim(b, [2, 2], (0, 1)) # t0: \"cuda:0 f32[2, 2]\"\n", - " # t1 = prims.mul(t0, a) # t1: \"cuda:0 f32[2, 2]\"|\n", - " It has the following subsymbols:\n", - " id=PrimIDs.BROADCAST_IN_DIM |t0 = prims.broadcast_in_dim(b, [2, 2], (0, 1)) # t0: \"cuda:0 f32[2, 2]\"|\n", - " id=PrimIDs.MUL |t1 = prims.mul(t0, a) # t1: \"cuda:0 f32[2, 2]\"|\n", - "Bound symbol with id=torch.add is represented in the trace as |t2 = ltorch.add(a, t1, alpha=None) # t2: \"cuda:0 f32[2, 2]\"\n", - " # t2 = prims.add(a, t1) # t2: \"cuda:0 f32[2, 2]\"|\n", - " It has the following subsymbols:\n", - " id=PrimIDs.ADD |t2 = prims.add(a, t1) # t2: \"cuda:0 f32[2, 2]\"|\n", - "Bound symbol with id=PrimIDs.RETURN is represented in the trace as |return t2|\n" - ] - } - ], - "source": [ - "# We can loop over the recorded operations that we call BoundSymbols\n", - "for bound_symbol in trace.bound_symbols:\n", - " print(f\"Bound symbol with id={bound_symbol.sym.id} is represented in the trace as |{bound_symbol}|\")\n", - " if bound_symbol.subsymbols:\n", - " print(\" It has the following subsymbols:\")\n", - " for subsymbol in bound_symbol.subsymbols:\n", - " print(f\" id={subsymbol.sym.id} |{subsymbol}|\")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "41566de2-a60f-4c87-a3d6-58e6a89dc38b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Help on function add_operator_executor in module thunder.executors:\n", - "\n", - "add_operator_executor(name, op_map, *, add_to_default_executors: bool = True) -> None\n", - "\n" - ] - } - ], - "source": [ - "from thunder.executors import add_operator_executor\n", - "\n", - "help(add_operator_executor)" - ] - }, - { - "cell_type": "markdown", - "id": "026680b3-7b46-4f4b-b16b-641fa9bdcdf4", - "metadata": {}, - "source": [ - "The key argument here is `op_map`.\n", - "\n", - "`op_map` is a dictionary with the id of the operator we're providing executor for as a key and `(name, checker_fn, implementation_fn)` tuple as a value.\n", - "\n", - "* `name` is the name of our execution function that would be appearing in the execution trace.\n", - "* `checker_fn` accepts the same set of arguments as the operator itself but returns `True` or `False` to signal to the executor orchestrator whether this particular set of inputs is supported or not.\n", - "* `implementation_fn` accepts real PyTorch tensors and expected to return PyTorch tensors." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "e02aaf0d", - "metadata": {}, - "outputs": [], - "source": [ - "# Let's define the addition function that can work only with NumPy's ndarrays\n", - "\n", - "@log\n", - "def add_numpy(a, b):\n", - " assert isinstance(a, np.ndarray), \"a must be a NumPy ndarray\"\n", - " assert isinstance(b, np.ndarray), \"b must be a NumPy ndarray\"\n", - " return np.add(a, b)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "7ddbe12e", - "metadata": {}, - "outputs": [], - "source": [ - "# We also need conversion functions from PyTorch to NumPy and back\n", - "@log\n", - "def torch_to_numpy(tensors):\n", - " return tuple(t.detach().cpu().numpy() for t in tensors)\n", - "\n", - "@log\n", - "def numpy_to_torch(arrays, device):\n", - " return tuple(torch.from_numpy(arr).to(device) for arr in arrays)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "2460f808-eacb-4a0f-8f62-6a17e3dce6e8", - "metadata": {}, - "outputs": [], - "source": [ - "@log\n", - "def checker_add_numpy(a, b):\n", - " # Suppose we only support float32 dtype, 2D, and (2, N) shape\n", - " first_condition = a.dtype == b.dtype == thunder.dtypes.float32\n", - " second_condition = a.ndim == b.ndim == 2\n", - " third_condition = a.shape[0] == b.shape[0] == 2\n", - " return first_condition and second_condition and third_condition" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "3a61b05f", - "metadata": {}, - "outputs": [], - "source": [ - "@log\n", - "def executor_add_numpy(a, b):\n", - " np_a, np_b = torch_to_numpy((a, b))\n", - " np_res = add_numpy(np_a, np_b)\n", - " res, = numpy_to_torch((np_res,), a.device)\n", - " return res" - ] - }, - { - "cell_type": "markdown", - "id": "c502944e", - "metadata": {}, - "source": [ - "Now we have all the pieces to create our executor." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "e7d3eb2f", - "metadata": {}, - "outputs": [], - "source": [ - "op_map = {\n", - " thunder.prims.PrimIDs.ADD: (\"add_numpy\", checker_add_numpy, executor_add_numpy)\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "11f71c82", - "metadata": {}, - "outputs": [], - "source": [ - "# Let's send our operator map to `add_operator_executor` to register our executor under the name \"custom_add_executor\"\n", - "\n", - "add_operator_executor(\"custom_add_executor\", op_map, add_to_default_executors=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "d864fa05", - "metadata": {}, - "outputs": [], - "source": [ - "# Let's test our executor\n", - "\n", - "cfun = thunder.compile(fun, executors_list=[\"custom_add_executor\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "24af4b99", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Could not find executor for bound symbol t1 = ltorch.mul(b, a) # t1: \"cuda:0 f32[2, 2]\"\n", - " # t0 = prims.broadcast_in_dim(b, [2, 2], (0, 1)) # t0: \"cuda:0 f32[2, 2]\"\n", - " # t1 = prims.mul(t0, a) # t1: \"cuda:0 f32[2, 2]\"\n" - ] - } - ], - "source": [ - "try:\n", - " cfun(a, b)\n", - "except RuntimeError as e:\n", - " print(e)" - ] - }, - { - "cell_type": "markdown", - "id": "d74d0c97", - "metadata": {}, - "source": [ - "The above function errors out because we haven't provided an executor for `ltorch.mul` yet. Let's do that." - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "f1ff48d0", - "metadata": {}, - "outputs": [], - "source": [ - "cfun = thunder.compile(fun, executors_list=[\"custom_add_executor\", thunder.executors.TORCH])" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "b1527d5e", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call checker_add_numpy(TensorProxy(name=a, shape=(2, 2), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(2, 2), dtype=float32, device=cuda:0))\n", - "|<- checker_add_numpy = True\n", - "\n", - "call checker_add_numpy(TensorProxy(name=a, shape=(2, 2), dtype=float32, device=cuda:0), TensorProxy(name=t1, shape=(2, 2), dtype=float32, device=cuda:0))\n", - "|<- checker_add_numpy = True\n", - "\n", - "call executor_add_numpy(Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-0.6906, -0.9761],\n", - " [ 0.9819, -0.1328]], device='cuda:0'), Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-1.3271, -1.8759],\n", - " [-0.2897, 0.0392]], device='cuda:0'))\n", - " call torch_to_numpy((Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-0.6906, -0.9761],\n", - " [ 0.9819, -0.1328]], device='cuda:0'), Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-1.3271, -1.8759],\n", - " [-0.2897, 0.0392]], device='cuda:0')))\n", - " |<- torch_to_numpy = ([[-0.6905969 -0.97613984]\n", - " [ 0.98193294 -0.13276565]], [[-1.3271405 -1.8758768 ]\n", - " [-0.28966585 0.03916528]])\n", - "\n", - " call add_numpy([[-0.6905969 -0.97613984]\n", - " [ 0.98193294 -0.13276565]], [[-1.3271405 -1.8758768 ]\n", - " [-0.28966585 0.03916528]])\n", - " |<- add_numpy = [[-2.0177374 -2.8520167 ]\n", - " [ 0.69226706 -0.09360038]]\n", - "\n", - " call numpy_to_torch(([[-2.0177374 -2.8520167 ]\n", - " [ 0.69226706 -0.09360038]]), cuda:0)\n", - " |<- numpy_to_torch = (Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-2.0177, -2.8520],\n", - " [ 0.6923, -0.0936]], device='cuda:0'))\n", - "\n", - "|<- executor_add_numpy = Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-2.0177, -2.8520],\n", - " [ 0.6923, -0.0936]], device='cuda:0')\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "tensor([[-2.0177, -2.8520],\n", - " [ 0.6923, -0.0936]], device='cuda:0')" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cfun(a, b)" - ] - }, - { - "cell_type": "markdown", - "id": "c55b5ed6", - "metadata": {}, - "source": [ - "Our logging decorator shows us that the `checker_add_numpy` function got called twice with `TensorProxy` as arguments and both times the function returned `True`. This means that our executor is going to be used for this particular execution trace.\n", - "\n", - "Then we see that the `executor_add_numpy` function is called with regular PyTorch tensors as arguments and it returns a regular PyTorch tensor." - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "a7ff30ef", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "# Constructed by Delete Last Used\n", - "# import torch as torch\n", - "import torch\n", - "\n", - "@torch.no_grad()\n", - "def fun(a, b):\n", - " # a: \"cuda:0 f32[2, 2]\" \n", - " # b: \"cuda:0 f32[2, 1]\" \n", - " t1 = torch.mul(b, a) # t1: \"cuda:0 f32[2, 2]\"\n", - " del [b]\n", - " t2 = add_numpy(a, t1) # t2: \"cuda:0 f32[2, 2]\"\n", - " del [a, t1]\n", - " return t2" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Let's check how our function is represented in the execution trace now\n", - "thunder.last_traces(cfun)[-1]" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "0868c882", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call executor_add_numpy(Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-0.6906, -0.9761],\n", - " [ 0.9819, -0.1328]], device='cuda:0'), Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-1.3271, -1.8759],\n", - " [-0.2897, 0.0392]], device='cuda:0'))\n", - " call torch_to_numpy((Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-0.6906, -0.9761],\n", - " [ 0.9819, -0.1328]], device='cuda:0'), Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-1.3271, -1.8759],\n", - " [-0.2897, 0.0392]], device='cuda:0')))\n", - " |<- torch_to_numpy = ([[-0.6905969 -0.97613984]\n", - " [ 0.98193294 -0.13276565]], [[-1.3271405 -1.8758768 ]\n", - " [-0.28966585 0.03916528]])\n", - "\n", - " call add_numpy([[-0.6905969 -0.97613984]\n", - " [ 0.98193294 -0.13276565]], [[-1.3271405 -1.8758768 ]\n", - " [-0.28966585 0.03916528]])\n", - " |<- add_numpy = [[-2.0177374 -2.8520167 ]\n", - " [ 0.69226706 -0.09360038]]\n", - "\n", - " call numpy_to_torch(([[-2.0177374 -2.8520167 ]\n", - " [ 0.69226706 -0.09360038]]), cuda:0)\n", - " |<- numpy_to_torch = (Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-2.0177, -2.8520],\n", - " [ 0.6923, -0.0936]], device='cuda:0'))\n", - "\n", - "|<- executor_add_numpy = Tensor(shape=torch.Size([2, 2]), stride=(2, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-2.0177, -2.8520],\n", - " [ 0.6923, -0.0936]], device='cuda:0')\n", - "\n" - ] - } - ], - "source": [ - "# Let's test whether the result is correct\n", - "cfun_torch = thunder.compile(fun, executors_list=[thunder.executors.TORCH])\n", - "expected = cfun_torch(a, b)\n", - "actual = cfun(a, b)\n", - "torch.testing.assert_close(expected, actual) # Should not raise an exception" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "f978b2de", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[SampleInput args=(tensor([[-5.6039, 5.0201, -8.2948, -0.1738],\n", - " [ 8.4915, -2.8353, -7.4601, -4.3015],\n", - " [ 6.0777, -7.6420, 3.4135, 3.2371],\n", - " [-0.8413, -1.7334, -1.0025, -0.7366]], device='cuda:0'), tensor([[ 4.5391, 1.5542, 7.9208, -1.3760],\n", - " [-6.5864, 8.6491, 6.1823, -1.8481],\n", - " [ 7.9385, -0.4884, 4.2281, 1.3158],\n", - " [-4.6107, 3.5805, 3.1749, -4.5989]], device='cuda:0')) kwargs={}]\n" - ] - } - ], - "source": [ - "from thunder.tests.opinfos import add_opinfo\n", - "\n", - "sample = next(add_opinfo.sample_input_generator(add_opinfo, device=\"cuda\", dtype=torch.float32, requires_grad=False))\n", - "print(sample)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "f07882f2", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call checker_add_numpy(TensorProxy(name=a, shape=(4, 4), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(4, 4), dtype=float32, device=cuda:0))\n", - "|<- checker_add_numpy = False\n", - "\n" - ] - } - ], - "source": [ - "# Let's test whether the result is correct\n", - "expected = cfun_torch(*sample.args)\n", - "actual = cfun(*sample.args)\n", - "torch.testing.assert_close(expected, actual) # Should not raise an exception" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "057689f5", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "# Constructed by Delete Last Used\n", - "# import torch as torch\n", - "import torch\n", - "\n", - "@torch.no_grad()\n", - "def fun(a, b):\n", - " # a: \"cuda:0 f32[2, 2]\" \n", - " # b: \"cuda:0 f32[2, 1]\" \n", - " t1 = torch.mul(b, a) # t1: \"cuda:0 f32[2, 2]\"\n", - " del [b]\n", - " t2 = torch.add(a, t1) # t2: \"cuda:0 f32[2, 2]\"\n", - " del [a, t1]\n", - " return t2" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# The order of executors matters today\n", - "cfun_torch_first = thunder.jit(fun, executors=[thunder.executors.TORCH, \"custom_add_executor\"])\n", - "cfun_torch_first(a, b)\n", - "thunder.last_traces(cfun_torch_first)[-1]" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "2aca9db7", - "metadata": {}, - "outputs": [], - "source": [ - "# Let's try inputs that are not supported by our executor\n", - "a = torch.randn(3, 2, device=\"cuda\", dtype=torch.float64)\n", - "b = torch.randn(3, 1, device=\"cuda\", dtype=torch.float64)" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "4b3e1589", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call checker_add_numpy(TensorProxy(name=a, shape=(3, 2), dtype=float64, device=cuda:0), TensorProxy(name=t1, shape=(3, 2), dtype=float64, device=cuda:0))\n", - "|<- checker_add_numpy = False\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "# Constructed by Delete Last Used\n", - "# import torch as torch\n", - "import torch\n", - "\n", - "@torch.no_grad()\n", - "def fun(a, b):\n", - " # a: \"cuda:0 f64[3, 2]\" \n", - " # b: \"cuda:0 f64[3, 1]\" \n", - " t1 = torch.mul(b, a) # t1: \"cuda:0 f64[3, 2]\"\n", - " del [b]\n", - " t2 = torch.add(a, t1) # t2: \"cuda:0 f64[3, 2]\"\n", - " del [a, t1]\n", - " return t2" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Let's see how our function is represented in the execution trace now with the new unsupported inputs\n", - "cfun(a, b)\n", - "thunder.last_traces(cfun)[-1]" - ] - }, - { - "cell_type": "markdown", - "id": "122ead11", - "metadata": {}, - "source": [ - "That's it! We've created our first executor. The process is very similar for other existing operators. There are two ingridients that are required to create an executor:\n", - "* `checker_fn` that checks whether the executor is applicable for a particular set of inputs (works with `TensorProxy` objects),\n", - "* `implementation_fn` that implements the operator (works with regular PyTorch tensors)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aec61cf6", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.4" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/dev_tutorials/extend.ipynb b/notebooks/dev_tutorials/extend.ipynb index 8b2fbe9036..304f6d9f6f 100644 --- a/notebooks/dev_tutorials/extend.ipynb +++ b/notebooks/dev_tutorials/extend.ipynb @@ -70,7 +70,7 @@ "source": [ "# Our operator executor will use the \"multimul\" function as a new example operator.\n", "# This function uses NumPy to perform two multiplications of four inputs.\n", - "# This functions very contrived, but will be useful to illustrate the extend submodule's capabilities.\n", + "# This function's contrived, but will be useful to illustrate the extend submodule's capabilities.\n", "def multimul_impl(\n", " a: Number | torch.Tensor, \n", " b: Number | torch.Tensor,\n", diff --git a/notebooks/dev_tutorials/fsdp_tutorial.ipynb b/notebooks/dev_tutorials/fsdp_tutorial.ipynb new file mode 100644 index 0000000000..a4f61b47c3 --- /dev/null +++ b/notebooks/dev_tutorials/fsdp_tutorial.ipynb @@ -0,0 +1,2025 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## FSDP Tutorial\n", + "\n", + "In this tutorial, we will walk through the implementation of Fully Sharded Data Parallel (FSDP) with Zero2 sharding strategy in `thunder`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Introduction\n", + "\n", + "In recent times, the LLM models have grown so large that all the model parameters don't fit on a single GPU. To circumvent this problem, there are various strategies like Tensor Parallel, Pipeline Parallel, Fully Sharded Data Parallel, etc to train these large models. In this tutorial, we discuss and implement Zero2 strategy for Fully Sharded Data Parallel (FSDP).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### What is Zero2 strategy for FSDP?\n", + "\n", + "In this strategy, we shard the model parameters across all the availabe GPUs. That is each GPU holds onto only a chunk of the parameter. During the forward pass, all GPUs call `all_gather` communication primitive to gather the parameters from other GPUs. Unlike Zero3 strategy which frees the parameter after forward pass, we save these unsharded parameters for backward pass. This is to save the overhead of extra communication. In the backward pass, we utilize the saved parameters and compute the gradients. Once the gradients are computed, we use `reduce_scatter` communication primitive to reduce (average) the gradients across all GPUs and scatter those gradients so that a given GPU holds only a chunk of gradient.\n", + "\n", + "For more information on FSDP, we recommend reading\n", + "\n", + "1. PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel - [Link](https://arxiv.org/abs/2304.11277)\n", + "2. ZeRO: Memory Optimizations Toward Training Trillion Parameter Models - [Link](https://arxiv.org/abs/1910.02054)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Example Model\n", + "\n", + "For this example we will have a simple model `Linear(Tanh(Linear(x)))` which will be sharded over 2 GPUs\n", + "\n", + "**NOTE**: We are generating the abstract trace so we don't actually need a system with 2 GPUs for this. It is only required when we execute this trace." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.distributed\n", + "import thunder\n", + "import thunder.distributed\n", + "from IPython.display import Code" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "device='cuda'\n", + "dim = 64\n", + "def create_model():\n", + " layers = [torch.nn.Linear(dim, dim, bias=False),\n", + " torch.nn.Tanh(),\n", + " torch.nn.Linear(dim, dim, bias=False)]\n", + " return torch.nn.Sequential(*layers).to(device)\n", + "\n", + "# Model\n", + "model = create_model()\n", + "# Input\n", + "x = torch.randn(dim, dim, device=device)\n", + "\n", + "\n", + "# we want to obtain a functional version of our model. The JIT does that internally and we reach into those\n", + "# internals here\n", + "thunder_model = thunder.jit(model)\n", + "cache_rec, i_, _ = thunder.compile_data(thunder_model).get_computation_and_inputs(x)\n", + "computation_trace = cache_rec.computation_traces[0]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def wrap_as_highlighted_code(trace):\n", + " return Code(str(trace), language=\"python\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can show the functional version:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
# Constructed by Dead Code Elimination (took 0 milliseconds)\n",
+       "import torch\n",
+       "import torch.nn.functional\n",
+       "from thunder.executors.torchex import no_autocast\n",
+       "\n",
+       "@torch.no_grad()\n",
+       "@no_autocast()\n",
+       "def augmented_forward_fn(*args):\n",
+       "  # args: "Collection" \n",
+       "  t0, \\\n",
+       "  t1, \\\n",
+       "  t2, \\\n",
+       "  = args\n",
+       "  t3 = torch.nn.functional.linear(t0, t1, None)  # t3: "cuda:0 f32[64, 64]"\n",
+       "    # t3 = ltorch.linear(t0, t1, None)  # t3: "cuda:0 f32[64, 64]"\n",
+       "      # t3 = prims.linear(t0, t1, None)  # t3: "cuda:0 f32[64, 64]"\n",
+       "  [t4] = nvFusion0(t3)\n",
+       "    # t4 = prims.tanh(t3)  # t4: "cuda:0 f32[64, 64]"\n",
+       "  t5 = torch.nn.functional.linear(t4, t2, None)  # t5: "cuda:0 f32[64, 64]"\n",
+       "    # t5 = ltorch.linear(t4, t2, None)  # t5: "cuda:0 f32[64, 64]"\n",
+       "      # t5 = prims.linear(t4, t2, None)  # t5: "cuda:0 f32[64, 64]"\n",
+       "  return {'output': t5, 'flat_args': [t0, t1, t2], 'flat_output': (t5,)}, ((t0, t2, t4), ())\n",
+       "
\n" + ], + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{c+c1}{\\PYZsh{} Constructed by Dead Code Elimination (took 0 milliseconds)}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{torch}\\PY{n+nn}{.}\\PY{n+nn}{nn}\\PY{n+nn}{.}\\PY{n+nn}{functional}\n", + "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n", + "\n", + "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", + "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", + "\\PY{k}{def} \\PY{n+nf}{augmented\\PYZus{}forward\\PYZus{}fn}\\PY{p}{(}\\PY{o}{*}\\PY{n}{args}\\PY{p}{)}\\PY{p}{:}\n", + " \\PY{c+c1}{\\PYZsh{} args: \\PYZdq{}Collection\\PYZdq{} }\n", + " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{n}{t2}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{o}{=} \\PY{n}{args}\n", + " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{functional}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t1}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t3 = ltorch.linear(t0, t1, None) \\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t3 = prims.linear(t0, t1, None) \\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{p}{[}\\PY{n}{t4}\\PY{p}{]} \\PY{o}{=} \\PY{n}{nvFusion0}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{)}\n", + " \\PY{c+c1}{\\PYZsh{} t4 = prims.tanh(t3) \\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{functional}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t4}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t5 = ltorch.linear(t4, t2, None) \\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t5 = prims.linear(t4, t2, None) \\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{return} \\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t5}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t1}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t5}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{n}{t4}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\n", + "\\end{Verbatim}\n" + ], + "text/plain": [ + "# Constructed by Dead Code Elimination (took 0 milliseconds)\n", + "import torch\n", + "import torch.nn.functional\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def augmented_forward_fn(*args):\n", + " # args: \"Collection\" \n", + " t0, \\\n", + " t1, \\\n", + " t2, \\\n", + " = args\n", + " t3 = torch.nn.functional.linear(t0, t1, None) # t3: \"cuda:0 f32[64, 64]\"\n", + " # t3 = ltorch.linear(t0, t1, None) # t3: \"cuda:0 f32[64, 64]\"\n", + " # t3 = prims.linear(t0, t1, None) # t3: \"cuda:0 f32[64, 64]\"\n", + " [t4] = nvFusion0(t3)\n", + " # t4 = prims.tanh(t3) # t4: \"cuda:0 f32[64, 64]\"\n", + " t5 = torch.nn.functional.linear(t4, t2, None) # t5: \"cuda:0 f32[64, 64]\"\n", + " # t5 = ltorch.linear(t4, t2, None) # t5: \"cuda:0 f32[64, 64]\"\n", + " # t5 = prims.linear(t4, t2, None) # t5: \"cuda:0 f32[64, 64]\"\n", + " return {'output': t5, 'flat_args': [t0, t1, t2], 'flat_output': (t5,)}, ((t0, t2, t4), ())" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "wrap_as_highlighted_code(computation_trace)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Step 1 : Configuration\n", + "\n", + "For our implementation of FSDP, we will generate the trace where we are sharding our model over 2 GPU" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# FSDP Config \n", + "# Usually these values are set in the environment by `torchrun` but for this example\n", + "# we will set them ourselves\n", + "world_size = 2 # We have two processes.\n", + "global_rank = 0 # Current process is the very first process." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Step 2: Function to shard parameters\n", + "\n", + "Next step is to write a function which will actually shard the parameters over 0-dim." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: We shard over 0th dimension of the param.\n", + "def shard_param(param: torch.Tensor, rank: int, world_size: int, name: str) -> None:\n", + " # We will keep it simple and error if param's 0th dim is not divisible by ``world_size``.\n", + " # Alternative is that we can pad our parameters so that they are divisible by `world_size`.\n", + " assert param.shape[0] % world_size == 0,(\n", + " f\"Current sharding requires the first dimension of the parameter {name!r} ({param.shape[0]})\"\n", + " f\" to be divisible by the world size ({world_size})\"\n", + " )\n", + " chunk_size = param.shape[0] // world_size\n", + "\n", + " # rank helps us determine which chunk of the parameter we will hold.\n", + " shard = param.data.narrow(0, chunk_size * rank, chunk_size).clone()\n", + " param.data = shard\n", + "\n", + "# Shard each parameter of the model\n", + "for param_name, param in model.named_parameters():\n", + " shard_param(param, global_rank, world_size, param_name)\n", + " # Mark the param to denote that it is sharded.\n", + " # This is required by the synchronization primitive we will use below.\n", + " param.ddp_type = thunder.core.proxies.DDPType.FULLY_SHARDED" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Sequential(\n", + " (0): Linear(in_features=64, out_features=64, bias=False)\n", + " (1): Tanh()\n", + " (2): Linear(in_features=64, out_features=64, bias=False)\n", + ")" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Verify our model looks as expected\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# Let us verify that we have actually sharded the parameters.\n", + "# Checking if the weight of 1st Linear layer is sharded over 0th dim.\n", + "assert model[0].weight.shape == (dim / world_size, dim)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Step 3: Add an operation to synchronize the parameters before calling the model.forward." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We have to create a process group. This is needed because the synchronization primitive `synchronize` that we will use to gather and scatter our weights in forward and backward requires a process group." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a process group\n", + "options = torch.distributed.distributed_c10d.ProcessGroup.Options(backend=\"nccl\")\n", + "process_group = torch.distributed.distributed_c10d.ProcessGroup(torch.distributed.distributed_c10d.Store(),\n", + " global_rank, world_size, options)\n", + "torch.distributed.distributed_c10d.GroupMember.WORLD = process_group" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "Because we are trying to play tricks with the traces and skip the part that inserts the synchronization automatically but also does the translation from PyTorch to thunder, we need to drop one layer of the trace to apply this manually.\n", + "(This is really hacky, don't try it at home!)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
# Constructed by Dead Code Elimination (took 0 milliseconds)\n",
+       "import thunder\n",
+       "import thunder.core.prims as prims\n",
+       "import thunder.torch as ltorch\n",
+       "import torch\n",
+       "import torch.nn.functional\n",
+       "from thunder.executors.torchex import no_autocast\n",
+       "\n",
+       "@torch.no_grad()\n",
+       "@no_autocast()\n",
+       "def augmented_forward_fn(*args):\n",
+       "  # args: "Collection" \n",
+       "  t0, \\\n",
+       "  t1, \\\n",
+       "  t2, \\\n",
+       "  = args\n",
+       "  t3 = ltorch.linear(t0, t1, None)  # t3: "cuda:0 f32[64, 64]"\n",
+       "    # t3 = ltorch.linear(t0, t1, None)  # t3: "cuda:0 f32[64, 64]"\n",
+       "      # t3 = prims.linear(t0, t1, None)  # t3: "cuda:0 f32[64, 64]"\n",
+       "  t4 = prims.tanh(t3)  # t4: "cuda:0 f32[64, 64]"\n",
+       "  t5 = ltorch.linear(t4, t2, None)  # t5: "cuda:0 f32[64, 64]"\n",
+       "    # t5 = ltorch.linear(t4, t2, None)  # t5: "cuda:0 f32[64, 64]"\n",
+       "      # t5 = prims.linear(t4, t2, None)  # t5: "cuda:0 f32[64, 64]"\n",
+       "  return {'output': t5, 'flat_args': [t0, t1, t2], 'flat_output': (t5,)}, ((t0, t2, t4), ())\n",
+       "
\n" + ], + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{c+c1}{\\PYZsh{} Constructed by Dead Code Elimination (took 0 milliseconds)}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{prims} \\PY{k}{as} \\PY{n+nn}{prims}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{torch} \\PY{k}{as} \\PY{n+nn}{ltorch}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{torch}\\PY{n+nn}{.}\\PY{n+nn}{nn}\\PY{n+nn}{.}\\PY{n+nn}{functional}\n", + "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n", + "\n", + "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", + "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", + "\\PY{k}{def} \\PY{n+nf}{augmented\\PYZus{}forward\\PYZus{}fn}\\PY{p}{(}\\PY{o}{*}\\PY{n}{args}\\PY{p}{)}\\PY{p}{:}\n", + " \\PY{c+c1}{\\PYZsh{} args: \\PYZdq{}Collection\\PYZdq{} }\n", + " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{n}{t2}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{o}{=} \\PY{n}{args}\n", + " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t1}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t3 = ltorch.linear(t0, t1, None) \\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t3 = prims.linear(t0, t1, None) \\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t4} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t4}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t5 = ltorch.linear(t4, t2, None) \\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t5 = prims.linear(t4, t2, None) \\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{return} \\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t5}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t1}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t5}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{n}{t4}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\n", + "\\end{Verbatim}\n" + ], + "text/plain": [ + "# Constructed by Dead Code Elimination (took 0 milliseconds)\n", + "import thunder\n", + "import thunder.core.prims as prims\n", + "import thunder.torch as ltorch\n", + "import torch\n", + "import torch.nn.functional\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def augmented_forward_fn(*args):\n", + " # args: \"Collection\" \n", + " t0, \\\n", + " t1, \\\n", + " t2, \\\n", + " = args\n", + " t3 = ltorch.linear(t0, t1, None) # t3: \"cuda:0 f32[64, 64]\"\n", + " # t3 = ltorch.linear(t0, t1, None) # t3: \"cuda:0 f32[64, 64]\"\n", + " # t3 = prims.linear(t0, t1, None) # t3: \"cuda:0 f32[64, 64]\"\n", + " t4 = prims.tanh(t3) # t4: \"cuda:0 f32[64, 64]\"\n", + " t5 = ltorch.linear(t4, t2, None) # t5: \"cuda:0 f32[64, 64]\"\n", + " # t5 = ltorch.linear(t4, t2, None) # t5: \"cuda:0 f32[64, 64]\"\n", + " # t5 = prims.linear(t4, t2, None) # t5: \"cuda:0 f32[64, 64]\"\n", + " return {'output': t5, 'flat_args': [t0, t1, t2], 'flat_output': (t5,)}, ((t0, t2, t4), ())" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "### DON'T TRY THIS AT HOME\n", + "computation_trace.bound_symbols[2].sym = cache_rec.computation_traces[0].bound_symbols[2].subsymbols[0].sym\n", + "if cache_rec.computation_traces[0].bound_symbols[3].subsymbols:\n", + " computation_trace.bound_symbols[3] = cache_rec.computation_traces[0].bound_symbols[3].subsymbols[0]\n", + "computation_trace.bound_symbols[4].sym = cache_rec.computation_traces[0].bound_symbols[4].subsymbols[0].sym\n", + "\n", + "wrap_as_highlighted_code(computation_trace)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# now we have a functional version of the model which\n", + "# takes as inputs the expected arguments and all the parameters.\n", + "functional_forward = computation_trace.python_callable()\n", + "\n", + "# This function creates a model with synchronization\n", + "# before calling the forward pass.\n", + "def model_with_syncs(x, *params):\n", + " # We call `prims.synchronize` on all the parameters.\n", + " # This is essentially calling `all_gather` so that we have the complete\n", + " # parameter before we actually to the forward computation.\n", + " unsharded_params = []\n", + " for param in params:\n", + " unsharded_params.append(thunder.distributed.prims.synchronize(param, process_group))\n", + "\n", + " return functional_forward(x, *unsharded_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let us now see what the trace of our model looks like with all the synchronization.\n", + "\n", + "Two main observations regarding the below trace \n", + "1. We can observe the `prims.synchronize` that we inserted using `model_with_syncs`.\n", + "2. Output of the `prims.synchronize` have the shape of unsharded (original) parameter.\n", + "\n", + "With this, we have implemented the FSDP for the forward pass of our model." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
# Constructed by Dead Code Elimination (took 0 milliseconds)\n",
+       "import thunder\n",
+       "import thunder.core.prims as prims\n",
+       "import thunder.distributed.prims\n",
+       "import thunder.torch as ltorch\n",
+       "import torch\n",
+       "from thunder.executors.torchex import no_autocast\n",
+       "\n",
+       "@torch.no_grad()\n",
+       "@no_autocast()\n",
+       "def model_with_syncs(x, *params):\n",
+       "  # x: "cuda:0 f32[64, 64]" \n",
+       "  # params: "Collection" \n",
+       "  t0, \\\n",
+       "  t1, \\\n",
+       "  = params\n",
+       "  t2 = thunder.distributed.prims.synchronize(t0, _torch_distributed_distributed_c10d_ProcessGroup_0)  # t2: "cuda:0 f32[64, 64]"\n",
+       "  t3 = thunder.distributed.prims.synchronize(t1, _torch_distributed_distributed_c10d_ProcessGroup_0)  # t3: "cuda:0 f32[64, 64]"\n",
+       "  t4 = ltorch.linear(x, t2, None)  # t4: "cuda:0 f32[64, 64]"\n",
+       "    # t4 = prims.linear(x, t2, None)  # t4: "cuda:0 f32[64, 64]"\n",
+       "  t5 = prims.tanh(t4)  # t5: "cuda:0 f32[64, 64]"\n",
+       "  t6 = ltorch.linear(t5, t3, None)  # t6: "cuda:0 f32[64, 64]"\n",
+       "    # t6 = prims.linear(t5, t3, None)  # t6: "cuda:0 f32[64, 64]"\n",
+       "  return ({'output': t6, 'flat_args': [x, t2, t3], 'flat_output': (t6,)}, ((x, t3, t5), ()))\n",
+       "
\n" + ], + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{c+c1}{\\PYZsh{} Constructed by Dead Code Elimination (took 0 milliseconds)}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{prims} \\PY{k}{as} \\PY{n+nn}{prims}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{distributed}\\PY{n+nn}{.}\\PY{n+nn}{prims}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{torch} \\PY{k}{as} \\PY{n+nn}{ltorch}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", + "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n", + "\n", + "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", + "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", + "\\PY{k}{def} \\PY{n+nf}{model\\PYZus{}with\\PYZus{}syncs}\\PY{p}{(}\\PY{n}{x}\\PY{p}{,} \\PY{o}{*}\\PY{n}{params}\\PY{p}{)}\\PY{p}{:}\n", + " \\PY{c+c1}{\\PYZsh{} x: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{} }\n", + " \\PY{c+c1}{\\PYZsh{} params: \\PYZdq{}Collection\\PYZdq{} }\n", + " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{o}{=} \\PY{n}{params}\n", + " \\PY{n}{t2} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{synchronize}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t2: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{synchronize}\\PY{p}{(}\\PY{n}{t1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t4} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{x}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t4 = prims.linear(x, t2, None) \\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t4}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t6} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t5}\\PY{p}{,} \\PY{n}{t3}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t6 = prims.linear(t5, t3, None) \\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{return} \\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t6}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{x}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{n}{t3}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t6}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{x}\\PY{p}{,} \\PY{n}{t3}\\PY{p}{,} \\PY{n}{t5}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", + "\\end{Verbatim}\n" + ], + "text/plain": [ + "# Constructed by Dead Code Elimination (took 0 milliseconds)\n", + "import thunder\n", + "import thunder.core.prims as prims\n", + "import thunder.distributed.prims\n", + "import thunder.torch as ltorch\n", + "import torch\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def model_with_syncs(x, *params):\n", + " # x: \"cuda:0 f32[64, 64]\" \n", + " # params: \"Collection\" \n", + " t0, \\\n", + " t1, \\\n", + " = params\n", + " t2 = thunder.distributed.prims.synchronize(t0, _torch_distributed_distributed_c10d_ProcessGroup_0) # t2: \"cuda:0 f32[64, 64]\"\n", + " t3 = thunder.distributed.prims.synchronize(t1, _torch_distributed_distributed_c10d_ProcessGroup_0) # t3: \"cuda:0 f32[64, 64]\"\n", + " t4 = ltorch.linear(x, t2, None) # t4: \"cuda:0 f32[64, 64]\"\n", + " # t4 = prims.linear(x, t2, None) # t4: \"cuda:0 f32[64, 64]\"\n", + " t5 = prims.tanh(t4) # t5: \"cuda:0 f32[64, 64]\"\n", + " t6 = ltorch.linear(t5, t3, None) # t6: \"cuda:0 f32[64, 64]\"\n", + " # t6 = prims.linear(t5, t3, None) # t6: \"cuda:0 f32[64, 64]\"\n", + " return ({'output': t6, 'flat_args': [x, t2, t3], 'flat_output': (t6,)}, ((x, t3, t5), ()))" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trace = thunder.trace()(model_with_syncs, x, *model.parameters())\n", + "\n", + "wrap_as_highlighted_code(trace)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For backward, we don't have to do anything because `thunder` already knows how to compute the backward of `prims.synchronize`. We can verify that by using the `value_and_grad` transform to generate the complete forward and backward trace together.\n", + "\n", + "Observations for the trace below:\n", + "1. `prims.synchronize` from previous trace is now decomposed into `prims.all_gather` and `prims.wait`. So, we can clearly see that we make a communication call to gather the parameter (which is asynchronous) and wait till we have the complete parameter.\n", + "2. At the end of the trace (after the forward and the backward computation), we see calls to `prims.reduce_scatter` and `prims.wait`. This takes care of reducing the gradients across all the GPUs and sharding them. One thing to note, for averaging gradients with low dynamic range dtype like `float16`, if we naively sum the gradients across GPUs before dividing by `world_size`, it can lead to overflows. So we scale the gradient tensor with `world_size`, before calling `reduce_scatter` with `sum` reduction to effectively average the gradients without overflow." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
# Constructed by Dead Code Elimination (took 1 milliseconds)\n",
+       "import thunder\n",
+       "import thunder.core.devices as devices\n",
+       "import thunder.core.dtypes as dtypes\n",
+       "import thunder.core.prims as prims\n",
+       "import thunder.distributed.prims\n",
+       "import thunder.torch as ltorch\n",
+       "import torch\n",
+       "from thunder.executors.torchex import no_autocast\n",
+       "\n",
+       "@torch.no_grad()\n",
+       "@no_autocast()\n",
+       "def _value_and_grad(*args):\n",
+       "  # args: "Collection" \n",
+       "  t0, \\\n",
+       "  t1, \\\n",
+       "  t2, \\\n",
+       "  = args\n",
+       "  t3 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t3: "cuda:0 f32[64, 64]"\n",
+       "  t4 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t4: "cuda:0 f32[64, 64]"\n",
+       "  t5 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t5: "cuda:0 f32[64, 64]"\n",
+       "  t6 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t6: "cuda:0 f32[64, 64]"\n",
+       "  t7 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t7: "cuda:0 f32[64, 64]"\n",
+       "  t8 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t8: "cuda:0 f32[64, 64]"\n",
+       "  t9 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t9: "cuda:0 f32[64, 64]"\n",
+       "  t10 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t10: "cuda:0 f32[64, 64]"\n",
+       "  p11 = thunder.distributed.prims.all_gather(t1, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p11: "FUTURE cuda:0 f32[64, 64]"\n",
+       "  t12 = thunder.distributed.prims.wait(p11)  # t12: "cuda:0 f32[64, 64]"\n",
+       "  p13 = thunder.distributed.prims.all_gather(t2, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p13: "FUTURE cuda:0 f32[64, 64]"\n",
+       "  t14 = thunder.distributed.prims.wait(p13)  # t14: "cuda:0 f32[64, 64]"\n",
+       "  t15 = prims.linear(t0, t12, None)  # t15: "cuda:0 f32[64, 64]"\n",
+       "  t16 = prims.tanh(t15)  # t16: "cuda:0 f32[64, 64]"\n",
+       "  t17 = prims.linear(t16, t14, None)  # t17: "cuda:0 f32[64, 64]"\n",
+       "  t18 = prims.add(t6, t7)  # t18: "cuda:0 f32[64, 64]"\n",
+       "  t19 = prims.add(t3, t8)  # t19: "cuda:0 f32[64, 64]"\n",
+       "  t20 = prims.add(t5, t9)  # t20: "cuda:0 f32[64, 64]"\n",
+       "  t21 = ltorch.reshape(t18, -1, 64)  # t21: "cuda:0 f32[64, 64]"\n",
+       "    # t21 = prims.reshape(t18, (64, 64))  # t21: "cuda:0 f32[64, 64]"\n",
+       "  t22 = ltorch.matmul(t21, t14)  # t22: "cuda:0 f32[64, 64]"\n",
+       "    # t22 = prims.matmul(t21, t14)  # t22: "cuda:0 f32[64, 64]"\n",
+       "  t23 = ltorch.reshape(t18, -1, 64)  # t23: "cuda:0 f32[64, 64]"\n",
+       "    # t23 = prims.reshape(t18, (64, 64))  # t23: "cuda:0 f32[64, 64]"\n",
+       "  t24 = prims.transpose(t23, (1, 0))  # t24: "cuda:0 f32[64, 64]"\n",
+       "  t25 = ltorch.reshape(t16, -1, 64)  # t25: "cuda:0 f32[64, 64]"\n",
+       "    # t25 = prims.reshape(t16, (64, 64))  # t25: "cuda:0 f32[64, 64]"\n",
+       "  t26 = ltorch.matmul(t24, t25)  # t26: "cuda:0 f32[64, 64]"\n",
+       "    # t26 = prims.matmul(t24, t25)  # t26: "cuda:0 f32[64, 64]"\n",
+       "  t27 = prims.add(t10, t22)  # t27: "cuda:0 f32[64, 64]"\n",
+       "  t28 = prims.add(t20, t26)  # t28: "cuda:0 f32[64, 64]"\n",
+       "  t29 = ltorch.mul(t16, t16)  # t29: "cuda:0 f32[64, 64]"\n",
+       "    # t29 = prims.mul(t16, t16)  # t29: "cuda:0 f32[64, 64]"\n",
+       "  t30 = ltorch.sub(1, t29, alpha=None)  # t30: "cuda:0 f32[64, 64]"\n",
+       "    # _ = prims.convert_element_type(1, float)\n",
+       "    # t30 = prims.sub(1.0, t29)  # t30: "cuda:0 f32[64, 64]"\n",
+       "  t31 = ltorch.mul(t27, t30)  # t31: "cuda:0 f32[64, 64]"\n",
+       "    # t31 = prims.mul(t27, t30)  # t31: "cuda:0 f32[64, 64]"\n",
+       "  t32 = ltorch.reshape(t31, -1, 64)  # t32: "cuda:0 f32[64, 64]"\n",
+       "    # t32 = prims.reshape(t31, (64, 64))  # t32: "cuda:0 f32[64, 64]"\n",
+       "  t33 = ltorch.matmul(t32, t12)  # t33: "cuda:0 f32[64, 64]"\n",
+       "    # t33 = prims.matmul(t32, t12)  # t33: "cuda:0 f32[64, 64]"\n",
+       "  t34 = ltorch.reshape(t31, -1, 64)  # t34: "cuda:0 f32[64, 64]"\n",
+       "    # t34 = prims.reshape(t31, (64, 64))  # t34: "cuda:0 f32[64, 64]"\n",
+       "  t35 = prims.transpose(t34, (1, 0))  # t35: "cuda:0 f32[64, 64]"\n",
+       "  t36 = ltorch.reshape(t0, -1, 64)  # t36: "cuda:0 f32[64, 64]"\n",
+       "    # t36 = prims.reshape(t0, (64, 64))  # t36: "cuda:0 f32[64, 64]"\n",
+       "  t37 = ltorch.matmul(t35, t36)  # t37: "cuda:0 f32[64, 64]"\n",
+       "    # t37 = prims.matmul(t35, t36)  # t37: "cuda:0 f32[64, 64]"\n",
+       "  t38 = prims.add(t19, t33)  # t38: "cuda:0 f32[64, 64]"\n",
+       "  t39 = prims.add(t4, t37)  # t39: "cuda:0 f32[64, 64]"\n",
+       "  t40 = ltorch.true_divide(t28, 2)  # t40: "cuda:0 f32[64, 64]"\n",
+       "    # _ = prims.convert_element_type(2, float)\n",
+       "    # t40 = prims.div(t28, 2.0)  # t40: "cuda:0 f32[64, 64]"\n",
+       "  p41 = thunder.distributed.prims.reduce_scatter(t40, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p41: "FUTURE cuda:0 f32[32, 64]"\n",
+       "  t42 = thunder.distributed.prims.wait(p41)  # t42: "cuda:0 f32[32, 64]"\n",
+       "  t43 = ltorch.true_divide(t39, 2)  # t43: "cuda:0 f32[64, 64]"\n",
+       "    # _ = prims.convert_element_type(2, float)\n",
+       "    # t43 = prims.div(t39, 2.0)  # t43: "cuda:0 f32[64, 64]"\n",
+       "  p44 = thunder.distributed.prims.reduce_scatter(t43, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p44: "FUTURE cuda:0 f32[32, 64]"\n",
+       "  t45 = thunder.distributed.prims.wait(p44)  # t45: "cuda:0 f32[32, 64]"\n",
+       "  return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t14, t16), ())), (t38, t45, t42))\n",
+       "
\n" + ], + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{c+c1}{\\PYZsh{} Constructed by Dead Code Elimination (took 1 milliseconds)}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{devices} \\PY{k}{as} \\PY{n+nn}{devices}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{dtypes} \\PY{k}{as} \\PY{n+nn}{dtypes}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{prims} \\PY{k}{as} \\PY{n+nn}{prims}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{distributed}\\PY{n+nn}{.}\\PY{n+nn}{prims}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{torch} \\PY{k}{as} \\PY{n+nn}{ltorch}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", + "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n", + "\n", + "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", + "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", + "\\PY{k}{def} \\PY{n+nf}{\\PYZus{}value\\PYZus{}and\\PYZus{}grad}\\PY{p}{(}\\PY{o}{*}\\PY{n}{args}\\PY{p}{)}\\PY{p}{:}\n", + " \\PY{c+c1}{\\PYZsh{} args: \\PYZdq{}Collection\\PYZdq{} }\n", + " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{n}{t2}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{o}{=} \\PY{n}{args}\n", + " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t4} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t6} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t7} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t7: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t8} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t8: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t9} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t9: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t10} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t10: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{p11} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{all\\PYZus{}gather}\\PY{p}{(}\\PY{n}{t1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p11: \\PYZdq{}FUTURE cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t12} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p11}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t12: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{p13} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{all\\PYZus{}gather}\\PY{p}{(}\\PY{n}{t2}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p13: \\PYZdq{}FUTURE cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t14} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p13}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t14: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t15} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t15: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t16} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t15}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t16: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t17} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t17: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t18} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t6}\\PY{p}{,} \\PY{n}{t7}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t18: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t19} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{,} \\PY{n}{t8}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t19: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t20} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t5}\\PY{p}{,} \\PY{n}{t9}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t21} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t18}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t21 = prims.reshape(t18, (64, 64)) \\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t22} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t21}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t22 = prims.matmul(t21, t14) \\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t23} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t18}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t23: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t23 = prims.reshape(t18, (64, 64)) \\PYZsh{} t23: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t24} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{transpose}\\PY{p}{(}\\PY{n}{t23}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t24: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t25} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t25 = prims.reshape(t16, (64, 64)) \\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t26} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t24}\\PY{p}{,} \\PY{n}{t25}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t26 = prims.matmul(t24, t25) \\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t27} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t10}\\PY{p}{,} \\PY{n}{t22}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t28} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t20}\\PY{p}{,} \\PY{n}{t26}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t28: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t29} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t29 = prims.mul(t16, t16) \\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t30} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{sub}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{t29}\\PY{p}{,} \\PY{n}{alpha}\\PY{o}{=}\\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(1, float)}\n", + " \\PY{c+c1}{\\PYZsh{} t30 = prims.sub(1.0, t29) \\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t31} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t27}\\PY{p}{,} \\PY{n}{t30}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t31 = prims.mul(t27, t30) \\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t32} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t31}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t32: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t32 = prims.reshape(t31, (64, 64)) \\PYZsh{} t32: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t33} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t32}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t33: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t33 = prims.matmul(t32, t12) \\PYZsh{} t33: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t34} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t31}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t34: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t34 = prims.reshape(t31, (64, 64)) \\PYZsh{} t34: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t35} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{transpose}\\PY{p}{(}\\PY{n}{t34}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t35: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t36} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t36: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t36 = prims.reshape(t0, (64, 64)) \\PYZsh{} t36: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t37} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t35}\\PY{p}{,} \\PY{n}{t36}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t37: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t37 = prims.matmul(t35, t36) \\PYZsh{} t37: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t38} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t19}\\PY{p}{,} \\PY{n}{t33}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t38: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t39} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t4}\\PY{p}{,} \\PY{n}{t37}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t39: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t40} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t28}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t40: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n", + " \\PY{c+c1}{\\PYZsh{} t40 = prims.div(t28, 2.0) \\PYZsh{} t40: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{p41} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{reduce\\PYZus{}scatter}\\PY{p}{(}\\PY{n}{t40}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p41: \\PYZdq{}FUTURE cuda:0 f32[32, 64]\\PYZdq{}}\n", + " \\PY{n}{t42} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p41}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t42: \\PYZdq{}cuda:0 f32[32, 64]\\PYZdq{}}\n", + " \\PY{n}{t43} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t39}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n", + " \\PY{c+c1}{\\PYZsh{} t43 = prims.div(t39, 2.0) \\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{p44} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{reduce\\PYZus{}scatter}\\PY{p}{(}\\PY{n}{t43}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p44: \\PYZdq{}FUTURE cuda:0 f32[32, 64]\\PYZdq{}}\n", + " \\PY{n}{t45} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p44}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t45: \\PYZdq{}cuda:0 f32[32, 64]\\PYZdq{}}\n", + " \\PY{k}{return} \\PY{p}{(}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t17}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t17}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{n}{t38}\\PY{p}{,} \\PY{n}{t45}\\PY{p}{,} \\PY{n}{t42}\\PY{p}{)}\\PY{p}{)}\n", + "\\end{Verbatim}\n" + ], + "text/plain": [ + "# Constructed by Dead Code Elimination (took 1 milliseconds)\n", + "import thunder\n", + "import thunder.core.devices as devices\n", + "import thunder.core.dtypes as dtypes\n", + "import thunder.core.prims as prims\n", + "import thunder.distributed.prims\n", + "import thunder.torch as ltorch\n", + "import torch\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def _value_and_grad(*args):\n", + " # args: \"Collection\" \n", + " t0, \\\n", + " t1, \\\n", + " t2, \\\n", + " = args\n", + " t3 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t3: \"cuda:0 f32[64, 64]\"\n", + " t4 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t4: \"cuda:0 f32[64, 64]\"\n", + " t5 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t5: \"cuda:0 f32[64, 64]\"\n", + " t6 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t6: \"cuda:0 f32[64, 64]\"\n", + " t7 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t7: \"cuda:0 f32[64, 64]\"\n", + " t8 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t8: \"cuda:0 f32[64, 64]\"\n", + " t9 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t9: \"cuda:0 f32[64, 64]\"\n", + " t10 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t10: \"cuda:0 f32[64, 64]\"\n", + " p11 = thunder.distributed.prims.all_gather(t1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p11: \"FUTURE cuda:0 f32[64, 64]\"\n", + " t12 = thunder.distributed.prims.wait(p11) # t12: \"cuda:0 f32[64, 64]\"\n", + " p13 = thunder.distributed.prims.all_gather(t2, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p13: \"FUTURE cuda:0 f32[64, 64]\"\n", + " t14 = thunder.distributed.prims.wait(p13) # t14: \"cuda:0 f32[64, 64]\"\n", + " t15 = prims.linear(t0, t12, None) # t15: \"cuda:0 f32[64, 64]\"\n", + " t16 = prims.tanh(t15) # t16: \"cuda:0 f32[64, 64]\"\n", + " t17 = prims.linear(t16, t14, None) # t17: \"cuda:0 f32[64, 64]\"\n", + " t18 = prims.add(t6, t7) # t18: \"cuda:0 f32[64, 64]\"\n", + " t19 = prims.add(t3, t8) # t19: \"cuda:0 f32[64, 64]\"\n", + " t20 = prims.add(t5, t9) # t20: \"cuda:0 f32[64, 64]\"\n", + " t21 = ltorch.reshape(t18, -1, 64) # t21: \"cuda:0 f32[64, 64]\"\n", + " # t21 = prims.reshape(t18, (64, 64)) # t21: \"cuda:0 f32[64, 64]\"\n", + " t22 = ltorch.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n", + " # t22 = prims.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n", + " t23 = ltorch.reshape(t18, -1, 64) # t23: \"cuda:0 f32[64, 64]\"\n", + " # t23 = prims.reshape(t18, (64, 64)) # t23: \"cuda:0 f32[64, 64]\"\n", + " t24 = prims.transpose(t23, (1, 0)) # t24: \"cuda:0 f32[64, 64]\"\n", + " t25 = ltorch.reshape(t16, -1, 64) # t25: \"cuda:0 f32[64, 64]\"\n", + " # t25 = prims.reshape(t16, (64, 64)) # t25: \"cuda:0 f32[64, 64]\"\n", + " t26 = ltorch.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n", + " # t26 = prims.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n", + " t27 = prims.add(t10, t22) # t27: \"cuda:0 f32[64, 64]\"\n", + " t28 = prims.add(t20, t26) # t28: \"cuda:0 f32[64, 64]\"\n", + " t29 = ltorch.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n", + " # t29 = prims.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n", + " t30 = ltorch.sub(1, t29, alpha=None) # t30: \"cuda:0 f32[64, 64]\"\n", + " # _ = prims.convert_element_type(1, float)\n", + " # t30 = prims.sub(1.0, t29) # t30: \"cuda:0 f32[64, 64]\"\n", + " t31 = ltorch.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n", + " # t31 = prims.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n", + " t32 = ltorch.reshape(t31, -1, 64) # t32: \"cuda:0 f32[64, 64]\"\n", + " # t32 = prims.reshape(t31, (64, 64)) # t32: \"cuda:0 f32[64, 64]\"\n", + " t33 = ltorch.matmul(t32, t12) # t33: \"cuda:0 f32[64, 64]\"\n", + " # t33 = prims.matmul(t32, t12) # t33: \"cuda:0 f32[64, 64]\"\n", + " t34 = ltorch.reshape(t31, -1, 64) # t34: \"cuda:0 f32[64, 64]\"\n", + " # t34 = prims.reshape(t31, (64, 64)) # t34: \"cuda:0 f32[64, 64]\"\n", + " t35 = prims.transpose(t34, (1, 0)) # t35: \"cuda:0 f32[64, 64]\"\n", + " t36 = ltorch.reshape(t0, -1, 64) # t36: \"cuda:0 f32[64, 64]\"\n", + " # t36 = prims.reshape(t0, (64, 64)) # t36: \"cuda:0 f32[64, 64]\"\n", + " t37 = ltorch.matmul(t35, t36) # t37: \"cuda:0 f32[64, 64]\"\n", + " # t37 = prims.matmul(t35, t36) # t37: \"cuda:0 f32[64, 64]\"\n", + " t38 = prims.add(t19, t33) # t38: \"cuda:0 f32[64, 64]\"\n", + " t39 = prims.add(t4, t37) # t39: \"cuda:0 f32[64, 64]\"\n", + " t40 = ltorch.true_divide(t28, 2) # t40: \"cuda:0 f32[64, 64]\"\n", + " # _ = prims.convert_element_type(2, float)\n", + " # t40 = prims.div(t28, 2.0) # t40: \"cuda:0 f32[64, 64]\"\n", + " p41 = thunder.distributed.prims.reduce_scatter(t40, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p41: \"FUTURE cuda:0 f32[32, 64]\"\n", + " t42 = thunder.distributed.prims.wait(p41) # t42: \"cuda:0 f32[32, 64]\"\n", + " t43 = ltorch.true_divide(t39, 2) # t43: \"cuda:0 f32[64, 64]\"\n", + " # _ = prims.convert_element_type(2, float)\n", + " # t43 = prims.div(t39, 2.0) # t43: \"cuda:0 f32[64, 64]\"\n", + " p44 = thunder.distributed.prims.reduce_scatter(t43, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p44: \"FUTURE cuda:0 f32[32, 64]\"\n", + " t45 = thunder.distributed.prims.wait(p44) # t45: \"cuda:0 f32[32, 64]\"\n", + " return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t14, t16), ())), (t38, t45, t42))" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from thunder.core.transforms import value_and_grad\n", + "\n", + "forward_and_backward_model = value_and_grad(model_with_syncs)\n", + "\n", + "forward_backward_trace = thunder.trace()(forward_and_backward_model, x, *model.parameters())\n", + "\n", + "wrap_as_highlighted_code(forward_backward_trace)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The above trace, only contains primitive which specifies the semantic of an operation abstractly but doesn't perform the actual computation.\n", + "\n", + "Now we will generate the execution trace which can actually perform the compute.\n", + "\n", + "In the execution trace generated below, we can see that all the primitives have been replaced with actually PyTorch operations. Also, our synchronization primitives have been replaced with PyTorch implementation provided by thunder i.e. `torch_all_gather_prim_impl`, `torch_reduce_scatter_prim_impl`, `torch_wait_prim_impl`." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
# Constructed by Delete Last Used (took 0 milliseconds)\n",
+       "import torch\n",
+       "import torch.nn.functional\n",
+       "from thunder.executors.torchex import no_autocast\n",
+       "\n",
+       "@torch.no_grad()\n",
+       "@no_autocast()\n",
+       "def _value_and_grad(*args):\n",
+       "  # args: "Collection" \n",
+       "  t0, \\\n",
+       "  t1, \\\n",
+       "  t2, \\\n",
+       "  = args\n",
+       "  del args\n",
+       "  t3 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t3: "cuda:0 f32[64, 64]"\n",
+       "    # t3 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t3: "cuda:0 f32[64, 64]"\n",
+       "      # t3 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t3: "cuda:0 f32[64, 64]"\n",
+       "  t4 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t4: "cuda:0 f32[64, 64]"\n",
+       "    # t4 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t4: "cuda:0 f32[64, 64]"\n",
+       "      # t4 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t4: "cuda:0 f32[64, 64]"\n",
+       "  t5 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t5: "cuda:0 f32[64, 64]"\n",
+       "    # t5 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t5: "cuda:0 f32[64, 64]"\n",
+       "      # t5 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t5: "cuda:0 f32[64, 64]"\n",
+       "  t6 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t6: "cuda:0 f32[64, 64]"\n",
+       "    # t6 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t6: "cuda:0 f32[64, 64]"\n",
+       "      # t6 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t6: "cuda:0 f32[64, 64]"\n",
+       "  t7 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t7: "cuda:0 f32[64, 64]"\n",
+       "    # t7 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t7: "cuda:0 f32[64, 64]"\n",
+       "      # t7 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t7: "cuda:0 f32[64, 64]"\n",
+       "  t8 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t8: "cuda:0 f32[64, 64]"\n",
+       "    # t8 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t8: "cuda:0 f32[64, 64]"\n",
+       "      # t8 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t8: "cuda:0 f32[64, 64]"\n",
+       "  t9 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t9: "cuda:0 f32[64, 64]"\n",
+       "    # t9 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t9: "cuda:0 f32[64, 64]"\n",
+       "      # t9 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t9: "cuda:0 f32[64, 64]"\n",
+       "  t10 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t10: "cuda:0 f32[64, 64]"\n",
+       "    # t10 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t10: "cuda:0 f32[64, 64]"\n",
+       "      # t10 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t10: "cuda:0 f32[64, 64]"\n",
+       "  p11 = torch_all_gather_prim_impl(t1, _torch_distributed_distributed_c10d_ProcessGroup_2, True)  # p11: "FUTURE cuda:0 f32[64, 64]"\n",
+       "  del t1\n",
+       "  t12 = torch_wait_prim_impl(p11)  # t12: "cuda:0 f32[64, 64]"\n",
+       "  del p11\n",
+       "  p13 = torch_all_gather_prim_impl(t2, _torch_distributed_distributed_c10d_ProcessGroup_2, True)  # p13: "FUTURE cuda:0 f32[64, 64]"\n",
+       "  del t2\n",
+       "  t14 = torch_wait_prim_impl(p13)  # t14: "cuda:0 f32[64, 64]"\n",
+       "  del p13\n",
+       "  t15 = torch.nn.functional.linear(t0, t12, None)  # t15: "cuda:0 f32[64, 64]"\n",
+       "    # t15 = ltorch.linear(t0, t12, None)  # t15: "cuda:0 f32[64, 64]"\n",
+       "      # t15 = prims.linear(t0, t12, None)  # t15: "cuda:0 f32[64, 64]"\n",
+       "  t16 = torch.tanh(t15)  # t16: "cuda:0 f32[64, 64]"\n",
+       "    # t16 = ltorch.tanh(t15)  # t16: "cuda:0 f32[64, 64]"\n",
+       "      # t16 = prims.tanh(t15)  # t16: "cuda:0 f32[64, 64]"\n",
+       "  del t15\n",
+       "  t17 = torch.nn.functional.linear(t16, t14, None)  # t17: "cuda:0 f32[64, 64]"\n",
+       "    # t17 = ltorch.linear(t16, t14, None)  # t17: "cuda:0 f32[64, 64]"\n",
+       "      # t17 = prims.linear(t16, t14, None)  # t17: "cuda:0 f32[64, 64]"\n",
+       "  t18 = torch.add(t6, t7)  # t18: "cuda:0 f32[64, 64]"\n",
+       "    # t18 = ltorch.add(t6, t7, alpha=None)  # t18: "cuda:0 f32[64, 64]"\n",
+       "      # t18 = prims.add(t6, t7)  # t18: "cuda:0 f32[64, 64]"\n",
+       "  del t6, t7\n",
+       "  t19 = torch.add(t3, t8)  # t19: "cuda:0 f32[64, 64]"\n",
+       "    # t19 = ltorch.add(t3, t8, alpha=None)  # t19: "cuda:0 f32[64, 64]"\n",
+       "      # t19 = prims.add(t3, t8)  # t19: "cuda:0 f32[64, 64]"\n",
+       "  del t3, t8\n",
+       "  t20 = torch.add(t5, t9)  # t20: "cuda:0 f32[64, 64]"\n",
+       "    # t20 = ltorch.add(t5, t9, alpha=None)  # t20: "cuda:0 f32[64, 64]"\n",
+       "      # t20 = prims.add(t5, t9)  # t20: "cuda:0 f32[64, 64]"\n",
+       "  del t5, t9\n",
+       "  t21 = torch.reshape(t18, (-1, 64))  # t21: "cuda:0 f32[64, 64]"\n",
+       "    # t21 = ltorch.reshape(t18, (-1, 64))  # t21: "cuda:0 f32[64, 64]"\n",
+       "      # t21 = prims.reshape(t18, (64, 64))  # t21: "cuda:0 f32[64, 64]"\n",
+       "  t22 = torch.matmul(t21, t14)  # t22: "cuda:0 f32[64, 64]"\n",
+       "    # t22 = ltorch.matmul(t21, t14)  # t22: "cuda:0 f32[64, 64]"\n",
+       "      # t22 = prims.matmul(t21, t14)  # t22: "cuda:0 f32[64, 64]"\n",
+       "  del t21\n",
+       "  t23 = torch.reshape(t18, (-1, 64))  # t23: "cuda:0 f32[64, 64]"\n",
+       "    # t23 = ltorch.reshape(t18, (-1, 64))  # t23: "cuda:0 f32[64, 64]"\n",
+       "      # t23 = prims.reshape(t18, (64, 64))  # t23: "cuda:0 f32[64, 64]"\n",
+       "  del t18\n",
+       "  t24 = torch.permute(t23, (1, 0))  # t24: "cuda:0 f32[64, 64]"\n",
+       "    # t24 = ltorch.permute(t23, (1, 0))  # t24: "cuda:0 f32[64, 64]"\n",
+       "      # t24 = prims.transpose(t23, (1, 0))  # t24: "cuda:0 f32[64, 64]"\n",
+       "  del t23\n",
+       "  t25 = torch.reshape(t16, (-1, 64))  # t25: "cuda:0 f32[64, 64]"\n",
+       "    # t25 = ltorch.reshape(t16, (-1, 64))  # t25: "cuda:0 f32[64, 64]"\n",
+       "      # t25 = prims.reshape(t16, (64, 64))  # t25: "cuda:0 f32[64, 64]"\n",
+       "  t26 = torch.matmul(t24, t25)  # t26: "cuda:0 f32[64, 64]"\n",
+       "    # t26 = ltorch.matmul(t24, t25)  # t26: "cuda:0 f32[64, 64]"\n",
+       "      # t26 = prims.matmul(t24, t25)  # t26: "cuda:0 f32[64, 64]"\n",
+       "  del t24, t25\n",
+       "  t27 = torch.add(t10, t22)  # t27: "cuda:0 f32[64, 64]"\n",
+       "    # t27 = ltorch.add(t10, t22, alpha=None)  # t27: "cuda:0 f32[64, 64]"\n",
+       "      # t27 = prims.add(t10, t22)  # t27: "cuda:0 f32[64, 64]"\n",
+       "  del t10, t22\n",
+       "  t28 = torch.add(t20, t26)  # t28: "cuda:0 f32[64, 64]"\n",
+       "    # t28 = ltorch.add(t20, t26, alpha=None)  # t28: "cuda:0 f32[64, 64]"\n",
+       "      # t28 = prims.add(t20, t26)  # t28: "cuda:0 f32[64, 64]"\n",
+       "  del t20, t26\n",
+       "  t29 = torch.mul(t16, t16)  # t29: "cuda:0 f32[64, 64]"\n",
+       "    # t29 = ltorch.mul(t16, t16)  # t29: "cuda:0 f32[64, 64]"\n",
+       "      # t29 = prims.mul(t16, t16)  # t29: "cuda:0 f32[64, 64]"\n",
+       "  t30 = torch.sub(1, t29)  # t30: "cuda:0 f32[64, 64]"\n",
+       "    # t30 = ltorch.sub(1, t29, alpha=None)  # t30: "cuda:0 f32[64, 64]"\n",
+       "      # _ = prims.convert_element_type(1, float)\n",
+       "      # t30 = prims.sub(1.0, t29)  # t30: "cuda:0 f32[64, 64]"\n",
+       "  del t29\n",
+       "  t31 = torch.mul(t27, t30)  # t31: "cuda:0 f32[64, 64]"\n",
+       "    # t31 = ltorch.mul(t27, t30)  # t31: "cuda:0 f32[64, 64]"\n",
+       "      # t31 = prims.mul(t27, t30)  # t31: "cuda:0 f32[64, 64]"\n",
+       "  del t27, t30\n",
+       "  t32 = torch.reshape(t31, (-1, 64))  # t32: "cuda:0 f32[64, 64]"\n",
+       "    # t32 = ltorch.reshape(t31, (-1, 64))  # t32: "cuda:0 f32[64, 64]"\n",
+       "      # t32 = prims.reshape(t31, (64, 64))  # t32: "cuda:0 f32[64, 64]"\n",
+       "  t33 = torch.matmul(t32, t12)  # t33: "cuda:0 f32[64, 64]"\n",
+       "    # t33 = ltorch.matmul(t32, t12)  # t33: "cuda:0 f32[64, 64]"\n",
+       "      # t33 = prims.matmul(t32, t12)  # t33: "cuda:0 f32[64, 64]"\n",
+       "  del t32\n",
+       "  t34 = torch.reshape(t31, (-1, 64))  # t34: "cuda:0 f32[64, 64]"\n",
+       "    # t34 = ltorch.reshape(t31, (-1, 64))  # t34: "cuda:0 f32[64, 64]"\n",
+       "      # t34 = prims.reshape(t31, (64, 64))  # t34: "cuda:0 f32[64, 64]"\n",
+       "  del t31\n",
+       "  t35 = torch.permute(t34, (1, 0))  # t35: "cuda:0 f32[64, 64]"\n",
+       "    # t35 = ltorch.permute(t34, (1, 0))  # t35: "cuda:0 f32[64, 64]"\n",
+       "      # t35 = prims.transpose(t34, (1, 0))  # t35: "cuda:0 f32[64, 64]"\n",
+       "  del t34\n",
+       "  t36 = torch.reshape(t0, (-1, 64))  # t36: "cuda:0 f32[64, 64]"\n",
+       "    # t36 = ltorch.reshape(t0, (-1, 64))  # t36: "cuda:0 f32[64, 64]"\n",
+       "      # t36 = prims.reshape(t0, (64, 64))  # t36: "cuda:0 f32[64, 64]"\n",
+       "  t37 = torch.matmul(t35, t36)  # t37: "cuda:0 f32[64, 64]"\n",
+       "    # t37 = ltorch.matmul(t35, t36)  # t37: "cuda:0 f32[64, 64]"\n",
+       "      # t37 = prims.matmul(t35, t36)  # t37: "cuda:0 f32[64, 64]"\n",
+       "  del t35, t36\n",
+       "  t38 = torch.add(t19, t33)  # t38: "cuda:0 f32[64, 64]"\n",
+       "    # t38 = ltorch.add(t19, t33, alpha=None)  # t38: "cuda:0 f32[64, 64]"\n",
+       "      # t38 = prims.add(t19, t33)  # t38: "cuda:0 f32[64, 64]"\n",
+       "  del t19, t33\n",
+       "  t39 = torch.add(t4, t37)  # t39: "cuda:0 f32[64, 64]"\n",
+       "    # t39 = ltorch.add(t4, t37, alpha=None)  # t39: "cuda:0 f32[64, 64]"\n",
+       "      # t39 = prims.add(t4, t37)  # t39: "cuda:0 f32[64, 64]"\n",
+       "  del t4, t37\n",
+       "  t40 = torch.true_divide(t28, 2)  # t40: "cuda:0 f32[64, 64]"\n",
+       "    # t40 = ltorch.true_divide(t28, 2)  # t40: "cuda:0 f32[64, 64]"\n",
+       "      # _ = prims.convert_element_type(2, float)\n",
+       "      # t40 = prims.div(t28, 2.0)  # t40: "cuda:0 f32[64, 64]"\n",
+       "  del t28\n",
+       "  p41 = torch_reduce_scatter_prim_impl(t40, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True)  # p41: "FUTURE cuda:0 f32[32, 64]"\n",
+       "  del t40\n",
+       "  t42 = torch_wait_prim_impl(p41)  # t42: "cuda:0 f32[32, 64]"\n",
+       "  del p41\n",
+       "  t43 = torch.true_divide(t39, 2)  # t43: "cuda:0 f32[64, 64]"\n",
+       "    # t43 = ltorch.true_divide(t39, 2)  # t43: "cuda:0 f32[64, 64]"\n",
+       "      # _ = prims.convert_element_type(2, float)\n",
+       "      # t43 = prims.div(t39, 2.0)  # t43: "cuda:0 f32[64, 64]"\n",
+       "  del t39\n",
+       "  p44 = torch_reduce_scatter_prim_impl(t43, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True)  # p44: "FUTURE cuda:0 f32[32, 64]"\n",
+       "  del t43\n",
+       "  t45 = torch_wait_prim_impl(p44)  # t45: "cuda:0 f32[32, 64]"\n",
+       "  del p44\n",
+       "  return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t14, t16), ())), (t38, t45, t42))\n",
+       "
\n" + ], + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{c+c1}{\\PYZsh{} Constructed by Delete Last Used (took 0 milliseconds)}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", + "\\PY{k+kn}{import} \\PY{n+nn}{torch}\\PY{n+nn}{.}\\PY{n+nn}{nn}\\PY{n+nn}{.}\\PY{n+nn}{functional}\n", + "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n", + "\n", + "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", + "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", + "\\PY{k}{def} \\PY{n+nf}{\\PYZus{}value\\PYZus{}and\\PYZus{}grad}\\PY{p}{(}\\PY{o}{*}\\PY{n}{args}\\PY{p}{)}\\PY{p}{:}\n", + " \\PY{c+c1}{\\PYZsh{} args: \\PYZdq{}Collection\\PYZdq{} }\n", + " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{n}{t2}\\PY{p}{,} \\PYZbs{}\n", + " \\PY{o}{=} \\PY{n}{args}\n", + " \\PY{k}{del} \\PY{n}{args}\n", + " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t3 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t3 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t4} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t4 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t4 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t5 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t5 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t6} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t6 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t6 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t7} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t7: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t7 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t7: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t7 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t7: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t8} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t8: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t8 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t8: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t8 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t8: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t9} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t9: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t9 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t9: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t9 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t9: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t10} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t10: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t10 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t10: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t10 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t10: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{p11} \\PY{o}{=} \\PY{n}{torch\\PYZus{}all\\PYZus{}gather\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{t1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}2}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p11: \\PYZdq{}FUTURE cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t1}\n", + " \\PY{n}{t12} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p11}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t12: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{p11}\n", + " \\PY{n}{p13} \\PY{o}{=} \\PY{n}{torch\\PYZus{}all\\PYZus{}gather\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{t2}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}2}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p13: \\PYZdq{}FUTURE cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t2}\n", + " \\PY{n}{t14} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p13}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t14: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{p13}\n", + " \\PY{n}{t15} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{functional}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t15: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t15 = ltorch.linear(t0, t12, None) \\PYZsh{} t15: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t15 = prims.linear(t0, t12, None) \\PYZsh{} t15: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t16} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t15}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t16: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t16 = ltorch.tanh(t15) \\PYZsh{} t16: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t16 = prims.tanh(t15) \\PYZsh{} t16: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t15}\n", + " \\PY{n}{t17} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{functional}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t17: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t17 = ltorch.linear(t16, t14, None) \\PYZsh{} t17: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t17 = prims.linear(t16, t14, None) \\PYZsh{} t17: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t18} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t6}\\PY{p}{,} \\PY{n}{t7}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t18: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t18 = ltorch.add(t6, t7, alpha=None) \\PYZsh{} t18: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t18 = prims.add(t6, t7) \\PYZsh{} t18: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t6}\\PY{p}{,} \\PY{n}{t7}\n", + " \\PY{n}{t19} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{,} \\PY{n}{t8}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t19: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t19 = ltorch.add(t3, t8, alpha=None) \\PYZsh{} t19: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t19 = prims.add(t3, t8) \\PYZsh{} t19: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t3}\\PY{p}{,} \\PY{n}{t8}\n", + " \\PY{n}{t20} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t5}\\PY{p}{,} \\PY{n}{t9}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t20 = ltorch.add(t5, t9, alpha=None) \\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t20 = prims.add(t5, t9) \\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t5}\\PY{p}{,} \\PY{n}{t9}\n", + " \\PY{n}{t21} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t18}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t21 = ltorch.reshape(t18, (\\PYZhy{}1, 64)) \\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t21 = prims.reshape(t18, (64, 64)) \\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t22} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t21}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t22 = ltorch.matmul(t21, t14) \\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t22 = prims.matmul(t21, t14) \\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t21}\n", + " \\PY{n}{t23} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t18}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t23: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t23 = ltorch.reshape(t18, (\\PYZhy{}1, 64)) \\PYZsh{} t23: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t23 = prims.reshape(t18, (64, 64)) \\PYZsh{} t23: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t18}\n", + " \\PY{n}{t24} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{permute}\\PY{p}{(}\\PY{n}{t23}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t24: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t24 = ltorch.permute(t23, (1, 0)) \\PYZsh{} t24: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t24 = prims.transpose(t23, (1, 0)) \\PYZsh{} t24: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t23}\n", + " \\PY{n}{t25} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t25 = ltorch.reshape(t16, (\\PYZhy{}1, 64)) \\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t25 = prims.reshape(t16, (64, 64)) \\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t26} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t24}\\PY{p}{,} \\PY{n}{t25}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t26 = ltorch.matmul(t24, t25) \\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t26 = prims.matmul(t24, t25) \\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t24}\\PY{p}{,} \\PY{n}{t25}\n", + " \\PY{n}{t27} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t10}\\PY{p}{,} \\PY{n}{t22}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t27 = ltorch.add(t10, t22, alpha=None) \\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t27 = prims.add(t10, t22) \\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t10}\\PY{p}{,} \\PY{n}{t22}\n", + " \\PY{n}{t28} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t20}\\PY{p}{,} \\PY{n}{t26}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t28: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t28 = ltorch.add(t20, t26, alpha=None) \\PYZsh{} t28: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t28 = prims.add(t20, t26) \\PYZsh{} t28: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t20}\\PY{p}{,} \\PY{n}{t26}\n", + " \\PY{n}{t29} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t29 = ltorch.mul(t16, t16) \\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t29 = prims.mul(t16, t16) \\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t30} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{sub}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{t29}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t30 = ltorch.sub(1, t29, alpha=None) \\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(1, float)}\n", + " \\PY{c+c1}{\\PYZsh{} t30 = prims.sub(1.0, t29) \\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t29}\n", + " \\PY{n}{t31} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t27}\\PY{p}{,} \\PY{n}{t30}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t31 = ltorch.mul(t27, t30) \\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t31 = prims.mul(t27, t30) \\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t27}\\PY{p}{,} \\PY{n}{t30}\n", + " \\PY{n}{t32} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t31}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t32: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t32 = ltorch.reshape(t31, (\\PYZhy{}1, 64)) \\PYZsh{} t32: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t32 = prims.reshape(t31, (64, 64)) \\PYZsh{} t32: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t33} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t32}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t33: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t33 = ltorch.matmul(t32, t12) \\PYZsh{} t33: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t33 = prims.matmul(t32, t12) \\PYZsh{} t33: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t32}\n", + " \\PY{n}{t34} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t31}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t34: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t34 = ltorch.reshape(t31, (\\PYZhy{}1, 64)) \\PYZsh{} t34: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t34 = prims.reshape(t31, (64, 64)) \\PYZsh{} t34: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t31}\n", + " \\PY{n}{t35} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{permute}\\PY{p}{(}\\PY{n}{t34}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t35: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t35 = ltorch.permute(t34, (1, 0)) \\PYZsh{} t35: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t35 = prims.transpose(t34, (1, 0)) \\PYZsh{} t35: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t34}\n", + " \\PY{n}{t36} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t36: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t36 = ltorch.reshape(t0, (\\PYZhy{}1, 64)) \\PYZsh{} t36: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t36 = prims.reshape(t0, (64, 64)) \\PYZsh{} t36: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{n}{t37} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t35}\\PY{p}{,} \\PY{n}{t36}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t37: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t37 = ltorch.matmul(t35, t36) \\PYZsh{} t37: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t37 = prims.matmul(t35, t36) \\PYZsh{} t37: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t35}\\PY{p}{,} \\PY{n}{t36}\n", + " \\PY{n}{t38} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t19}\\PY{p}{,} \\PY{n}{t33}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t38: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t38 = ltorch.add(t19, t33, alpha=None) \\PYZsh{} t38: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t38 = prims.add(t19, t33) \\PYZsh{} t38: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t19}\\PY{p}{,} \\PY{n}{t33}\n", + " \\PY{n}{t39} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t4}\\PY{p}{,} \\PY{n}{t37}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t39: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t39 = ltorch.add(t4, t37, alpha=None) \\PYZsh{} t39: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t39 = prims.add(t4, t37) \\PYZsh{} t39: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t4}\\PY{p}{,} \\PY{n}{t37}\n", + " \\PY{n}{t40} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t28}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t40: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t40 = ltorch.true\\PYZus{}divide(t28, 2) \\PYZsh{} t40: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n", + " \\PY{c+c1}{\\PYZsh{} t40 = prims.div(t28, 2.0) \\PYZsh{} t40: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t28}\n", + " \\PY{n}{p41} \\PY{o}{=} \\PY{n}{torch\\PYZus{}reduce\\PYZus{}scatter\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{t40}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}3}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}2}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p41: \\PYZdq{}FUTURE cuda:0 f32[32, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t40}\n", + " \\PY{n}{t42} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p41}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t42: \\PYZdq{}cuda:0 f32[32, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{p41}\n", + " \\PY{n}{t43} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t39}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} t43 = ltorch.true\\PYZus{}divide(t39, 2) \\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n", + " \\PY{c+c1}{\\PYZsh{} t43 = prims.div(t39, 2.0) \\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t39}\n", + " \\PY{n}{p44} \\PY{o}{=} \\PY{n}{torch\\PYZus{}reduce\\PYZus{}scatter\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{t43}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}3}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}2}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p44: \\PYZdq{}FUTURE cuda:0 f32[32, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{t43}\n", + " \\PY{n}{t45} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p44}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t45: \\PYZdq{}cuda:0 f32[32, 64]\\PYZdq{}}\n", + " \\PY{k}{del} \\PY{n}{p44}\n", + " \\PY{k}{return} \\PY{p}{(}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t17}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t17}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{n}{t38}\\PY{p}{,} \\PY{n}{t45}\\PY{p}{,} \\PY{n}{t42}\\PY{p}{)}\\PY{p}{)}\n", + "\\end{Verbatim}\n" + ], + "text/plain": [ + "# Constructed by Delete Last Used (took 0 milliseconds)\n", + "import torch\n", + "import torch.nn.functional\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def _value_and_grad(*args):\n", + " # args: \"Collection\" \n", + " t0, \\\n", + " t1, \\\n", + " t2, \\\n", + " = args\n", + " del args\n", + " t3 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t3: \"cuda:0 f32[64, 64]\"\n", + " # t3 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t3: \"cuda:0 f32[64, 64]\"\n", + " # t3 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t3: \"cuda:0 f32[64, 64]\"\n", + " t4 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t4: \"cuda:0 f32[64, 64]\"\n", + " # t4 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t4: \"cuda:0 f32[64, 64]\"\n", + " # t4 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t4: \"cuda:0 f32[64, 64]\"\n", + " t5 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t5: \"cuda:0 f32[64, 64]\"\n", + " # t5 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t5: \"cuda:0 f32[64, 64]\"\n", + " # t5 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t5: \"cuda:0 f32[64, 64]\"\n", + " t6 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t6: \"cuda:0 f32[64, 64]\"\n", + " # t6 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t6: \"cuda:0 f32[64, 64]\"\n", + " # t6 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t6: \"cuda:0 f32[64, 64]\"\n", + " t7 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t7: \"cuda:0 f32[64, 64]\"\n", + " # t7 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t7: \"cuda:0 f32[64, 64]\"\n", + " # t7 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t7: \"cuda:0 f32[64, 64]\"\n", + " t8 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t8: \"cuda:0 f32[64, 64]\"\n", + " # t8 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t8: \"cuda:0 f32[64, 64]\"\n", + " # t8 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t8: \"cuda:0 f32[64, 64]\"\n", + " t9 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t9: \"cuda:0 f32[64, 64]\"\n", + " # t9 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t9: \"cuda:0 f32[64, 64]\"\n", + " # t9 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t9: \"cuda:0 f32[64, 64]\"\n", + " t10 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t10: \"cuda:0 f32[64, 64]\"\n", + " # t10 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t10: \"cuda:0 f32[64, 64]\"\n", + " # t10 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t10: \"cuda:0 f32[64, 64]\"\n", + " p11 = torch_all_gather_prim_impl(t1, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p11: \"FUTURE cuda:0 f32[64, 64]\"\n", + " del t1\n", + " t12 = torch_wait_prim_impl(p11) # t12: \"cuda:0 f32[64, 64]\"\n", + " del p11\n", + " p13 = torch_all_gather_prim_impl(t2, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p13: \"FUTURE cuda:0 f32[64, 64]\"\n", + " del t2\n", + " t14 = torch_wait_prim_impl(p13) # t14: \"cuda:0 f32[64, 64]\"\n", + " del p13\n", + " t15 = torch.nn.functional.linear(t0, t12, None) # t15: \"cuda:0 f32[64, 64]\"\n", + " # t15 = ltorch.linear(t0, t12, None) # t15: \"cuda:0 f32[64, 64]\"\n", + " # t15 = prims.linear(t0, t12, None) # t15: \"cuda:0 f32[64, 64]\"\n", + " t16 = torch.tanh(t15) # t16: \"cuda:0 f32[64, 64]\"\n", + " # t16 = ltorch.tanh(t15) # t16: \"cuda:0 f32[64, 64]\"\n", + " # t16 = prims.tanh(t15) # t16: \"cuda:0 f32[64, 64]\"\n", + " del t15\n", + " t17 = torch.nn.functional.linear(t16, t14, None) # t17: \"cuda:0 f32[64, 64]\"\n", + " # t17 = ltorch.linear(t16, t14, None) # t17: \"cuda:0 f32[64, 64]\"\n", + " # t17 = prims.linear(t16, t14, None) # t17: \"cuda:0 f32[64, 64]\"\n", + " t18 = torch.add(t6, t7) # t18: \"cuda:0 f32[64, 64]\"\n", + " # t18 = ltorch.add(t6, t7, alpha=None) # t18: \"cuda:0 f32[64, 64]\"\n", + " # t18 = prims.add(t6, t7) # t18: \"cuda:0 f32[64, 64]\"\n", + " del t6, t7\n", + " t19 = torch.add(t3, t8) # t19: \"cuda:0 f32[64, 64]\"\n", + " # t19 = ltorch.add(t3, t8, alpha=None) # t19: \"cuda:0 f32[64, 64]\"\n", + " # t19 = prims.add(t3, t8) # t19: \"cuda:0 f32[64, 64]\"\n", + " del t3, t8\n", + " t20 = torch.add(t5, t9) # t20: \"cuda:0 f32[64, 64]\"\n", + " # t20 = ltorch.add(t5, t9, alpha=None) # t20: \"cuda:0 f32[64, 64]\"\n", + " # t20 = prims.add(t5, t9) # t20: \"cuda:0 f32[64, 64]\"\n", + " del t5, t9\n", + " t21 = torch.reshape(t18, (-1, 64)) # t21: \"cuda:0 f32[64, 64]\"\n", + " # t21 = ltorch.reshape(t18, (-1, 64)) # t21: \"cuda:0 f32[64, 64]\"\n", + " # t21 = prims.reshape(t18, (64, 64)) # t21: \"cuda:0 f32[64, 64]\"\n", + " t22 = torch.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n", + " # t22 = ltorch.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n", + " # t22 = prims.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n", + " del t21\n", + " t23 = torch.reshape(t18, (-1, 64)) # t23: \"cuda:0 f32[64, 64]\"\n", + " # t23 = ltorch.reshape(t18, (-1, 64)) # t23: \"cuda:0 f32[64, 64]\"\n", + " # t23 = prims.reshape(t18, (64, 64)) # t23: \"cuda:0 f32[64, 64]\"\n", + " del t18\n", + " t24 = torch.permute(t23, (1, 0)) # t24: \"cuda:0 f32[64, 64]\"\n", + " # t24 = ltorch.permute(t23, (1, 0)) # t24: \"cuda:0 f32[64, 64]\"\n", + " # t24 = prims.transpose(t23, (1, 0)) # t24: \"cuda:0 f32[64, 64]\"\n", + " del t23\n", + " t25 = torch.reshape(t16, (-1, 64)) # t25: \"cuda:0 f32[64, 64]\"\n", + " # t25 = ltorch.reshape(t16, (-1, 64)) # t25: \"cuda:0 f32[64, 64]\"\n", + " # t25 = prims.reshape(t16, (64, 64)) # t25: \"cuda:0 f32[64, 64]\"\n", + " t26 = torch.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n", + " # t26 = ltorch.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n", + " # t26 = prims.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n", + " del t24, t25\n", + " t27 = torch.add(t10, t22) # t27: \"cuda:0 f32[64, 64]\"\n", + " # t27 = ltorch.add(t10, t22, alpha=None) # t27: \"cuda:0 f32[64, 64]\"\n", + " # t27 = prims.add(t10, t22) # t27: \"cuda:0 f32[64, 64]\"\n", + " del t10, t22\n", + " t28 = torch.add(t20, t26) # t28: \"cuda:0 f32[64, 64]\"\n", + " # t28 = ltorch.add(t20, t26, alpha=None) # t28: \"cuda:0 f32[64, 64]\"\n", + " # t28 = prims.add(t20, t26) # t28: \"cuda:0 f32[64, 64]\"\n", + " del t20, t26\n", + " t29 = torch.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n", + " # t29 = ltorch.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n", + " # t29 = prims.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n", + " t30 = torch.sub(1, t29) # t30: \"cuda:0 f32[64, 64]\"\n", + " # t30 = ltorch.sub(1, t29, alpha=None) # t30: \"cuda:0 f32[64, 64]\"\n", + " # _ = prims.convert_element_type(1, float)\n", + " # t30 = prims.sub(1.0, t29) # t30: \"cuda:0 f32[64, 64]\"\n", + " del t29\n", + " t31 = torch.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n", + " # t31 = ltorch.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n", + " # t31 = prims.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n", + " del t27, t30\n", + " t32 = torch.reshape(t31, (-1, 64)) # t32: \"cuda:0 f32[64, 64]\"\n", + " # t32 = ltorch.reshape(t31, (-1, 64)) # t32: \"cuda:0 f32[64, 64]\"\n", + " # t32 = prims.reshape(t31, (64, 64)) # t32: \"cuda:0 f32[64, 64]\"\n", + " t33 = torch.matmul(t32, t12) # t33: \"cuda:0 f32[64, 64]\"\n", + " # t33 = ltorch.matmul(t32, t12) # t33: \"cuda:0 f32[64, 64]\"\n", + " # t33 = prims.matmul(t32, t12) # t33: \"cuda:0 f32[64, 64]\"\n", + " del t32\n", + " t34 = torch.reshape(t31, (-1, 64)) # t34: \"cuda:0 f32[64, 64]\"\n", + " # t34 = ltorch.reshape(t31, (-1, 64)) # t34: \"cuda:0 f32[64, 64]\"\n", + " # t34 = prims.reshape(t31, (64, 64)) # t34: \"cuda:0 f32[64, 64]\"\n", + " del t31\n", + " t35 = torch.permute(t34, (1, 0)) # t35: \"cuda:0 f32[64, 64]\"\n", + " # t35 = ltorch.permute(t34, (1, 0)) # t35: \"cuda:0 f32[64, 64]\"\n", + " # t35 = prims.transpose(t34, (1, 0)) # t35: \"cuda:0 f32[64, 64]\"\n", + " del t34\n", + " t36 = torch.reshape(t0, (-1, 64)) # t36: \"cuda:0 f32[64, 64]\"\n", + " # t36 = ltorch.reshape(t0, (-1, 64)) # t36: \"cuda:0 f32[64, 64]\"\n", + " # t36 = prims.reshape(t0, (64, 64)) # t36: \"cuda:0 f32[64, 64]\"\n", + " t37 = torch.matmul(t35, t36) # t37: \"cuda:0 f32[64, 64]\"\n", + " # t37 = ltorch.matmul(t35, t36) # t37: \"cuda:0 f32[64, 64]\"\n", + " # t37 = prims.matmul(t35, t36) # t37: \"cuda:0 f32[64, 64]\"\n", + " del t35, t36\n", + " t38 = torch.add(t19, t33) # t38: \"cuda:0 f32[64, 64]\"\n", + " # t38 = ltorch.add(t19, t33, alpha=None) # t38: \"cuda:0 f32[64, 64]\"\n", + " # t38 = prims.add(t19, t33) # t38: \"cuda:0 f32[64, 64]\"\n", + " del t19, t33\n", + " t39 = torch.add(t4, t37) # t39: \"cuda:0 f32[64, 64]\"\n", + " # t39 = ltorch.add(t4, t37, alpha=None) # t39: \"cuda:0 f32[64, 64]\"\n", + " # t39 = prims.add(t4, t37) # t39: \"cuda:0 f32[64, 64]\"\n", + " del t4, t37\n", + " t40 = torch.true_divide(t28, 2) # t40: \"cuda:0 f32[64, 64]\"\n", + " # t40 = ltorch.true_divide(t28, 2) # t40: \"cuda:0 f32[64, 64]\"\n", + " # _ = prims.convert_element_type(2, float)\n", + " # t40 = prims.div(t28, 2.0) # t40: \"cuda:0 f32[64, 64]\"\n", + " del t28\n", + " p41 = torch_reduce_scatter_prim_impl(t40, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p41: \"FUTURE cuda:0 f32[32, 64]\"\n", + " del t40\n", + " t42 = torch_wait_prim_impl(p41) # t42: \"cuda:0 f32[32, 64]\"\n", + " del p41\n", + " t43 = torch.true_divide(t39, 2) # t43: \"cuda:0 f32[64, 64]\"\n", + " # t43 = ltorch.true_divide(t39, 2) # t43: \"cuda:0 f32[64, 64]\"\n", + " # _ = prims.convert_element_type(2, float)\n", + " # t43 = prims.div(t39, 2.0) # t43: \"cuda:0 f32[64, 64]\"\n", + " del t39\n", + " p44 = torch_reduce_scatter_prim_impl(t43, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p44: \"FUTURE cuda:0 f32[32, 64]\"\n", + " del t43\n", + " t45 = torch_wait_prim_impl(p44) # t45: \"cuda:0 f32[32, 64]\"\n", + " del p44\n", + " return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t14, t16), ())), (t38, t45, t42))" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "optimized_trace = thunder.transform_for_execution(forward_backward_trace, executors_list=thunder.get_always_executors())\n", + "\n", + "# Grab the final trace\n", + "exec_trace = optimized_trace[-1]\n", + "wrap_as_highlighted_code(exec_trace)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Step 4 : Running the actual computation\n", + "\n", + "Running the actual computation will require setting up 2 processes and running our above code in both those processes (which can be tricky with Jupyter Notebook). Instead, we will write a small script and run it with `torchrun` which takes care of setting up the processes and relevant state.\n", + "\n", + "**NOTE**: This requires device running this notebook to have at least 2-GPUs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the example below, we will use `thunder.distributed.fsdp` which does the same as what we did above (with some extra checks). The code below should look familiar as it is roughly all the above pieces in a single script. " + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Overwriting thunder_fsdp_simple_example.py\n" + ] + } + ], + "source": [ + "%%writefile thunder_fsdp_simple_example.py\n", + "\n", + "# imports\n", + "from thunder.tests.lit_gpt_model import GPT, Config\n", + "import torch\n", + "import torch.distributed\n", + "import thunder\n", + "import thunder.distributed\n", + "import os\n", + "\n", + "# # # # # # # #\n", + "# Create Model\n", + "# # # # # # # #\n", + "\n", + "# NOTE: We create the model on CPU.\n", + "device='cpu'\n", + "dim = 64\n", + "def create_model():\n", + " layers = []\n", + " layers.append(torch.nn.Linear(dim, dim))\n", + " layers.append(torch.nn.ReLU())\n", + " layers.append(torch.nn.Linear(dim, dim))\n", + " return torch.nn.Sequential(*layers).to(device)\n", + "\n", + "# Model\n", + "model = create_model()\n", + "# Input\n", + "x = torch.randn(dim, dim, device=device)\n", + "\n", + "# # # # # # # #\n", + "# Setup for distributed\n", + "# # # # # # # #\n", + "torch.distributed.init_process_group(backend='nccl')\n", + "\n", + "rank = int(os.environ[\"LOCAL_RANK\"])\n", + "\n", + "device = f\"cuda:{rank}\"\n", + "\n", + "# # # # # # # #\n", + "# Move inputs to correct device\n", + "# # # # # # # #\n", + "x = x.to(device)\n", + "\n", + "# # # # # # # #\n", + "# Wrap the model in thunder.distributed.fsdp\n", + "# # # # # # # #\n", + "\n", + "# thunder.distributed.fsdp takes care of moving the parameter\n", + "# shard to the correct GPU for the current process.\n", + "cmodel = thunder.jit(thunder.distributed.fsdp(model))\n", + "\n", + "# Run the forward pass.\n", + "cmodel(x)\n", + "\n", + "# # # # # # # #\n", + "# Check the traces\n", + "# # # # # # # #\n", + "fwd_traces = thunder.last_traces(cmodel)\n", + "bwd_traces = thunder.last_backward_traces(cmodel)\n", + "\n", + "# # # # # # # #\n", + "# Print and check to see if they match ours\n", + "# # # # # # # #\n", + "if rank == 0:\n", + " print(fwd_traces[-1])\n", + " print(\"*******\"* 8)\n", + " print(bwd_traces[-1])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let us run the above script and check what the trace looks like.\n", + "\n", + "We can observe that forward trace has `torch_all_gather_prim_impl` to gather the parameter before forward pass and the backward trace has `torch_reduce_scatter_prim_impl` to reduce and scatter the gradients back to different GPUs. This is similar to our implementation above." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "W0314 08:26:39.130000 140292199276608 torch/distributed/run.py:757] \n", + "W0314 08:26:39.130000 140292199276608 torch/distributed/run.py:757] *****************************************\n", + "W0314 08:26:39.130000 140292199276608 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \n", + "W0314 08:26:39.130000 140292199276608 torch/distributed/run.py:757] *****************************************\n", + "# Constructed by Delete Last Used (took 0 milliseconds)\n", + "import torch\n", + "import torch.nn.functional\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def augmented_forward_fn(input, t_0_bias, t_2_bias, t_0_weight, t_2_weight):\n", + " # input: \"cuda:0 f32[64, 64]\" \n", + " # t_0_bias: \"cuda:0 f32[32]\" \n", + " p0 = torch_all_gather_prim_impl(t_0_bias, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p0: \"FUTURE cuda:0 f32[64]\"\n", + " # t_2_bias: \"cuda:0 f32[32]\" \n", + " p2 = torch_all_gather_prim_impl(t_2_bias, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p2: \"FUTURE cuda:0 f32[64]\"\n", + " # t_0_weight: \"cuda:0 f32[32, 64]\" \n", + " p4 = torch_all_gather_prim_impl(t_0_weight, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p4: \"FUTURE cuda:0 f32[64, 64]\"\n", + " # t_2_weight: \"cuda:0 f32[32, 64]\" \n", + " p9 = torch_all_gather_prim_impl(t_2_weight, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p9: \"FUTURE cuda:0 f32[64, 64]\"\n", + " t1 = torch_wait_prim_impl(p0) # t1: \"cuda:0 f32[64]\"\n", + " del p0\n", + " t3 = torch_wait_prim_impl(p2) # t3: \"cuda:0 f32[64]\"\n", + " del p2\n", + " t5 = torch_wait_prim_impl(p4) # t5: \"cuda:0 f32[64, 64]\"\n", + " del p4\n", + " t6 = torch.nn.functional.linear(input, t5, t1) # t6: \"cuda:0 f32[64, 64]\"\n", + " # t6 = ltorch.linear(input, t5, t1) # t6: \"cuda:0 f32[64, 64]\"\n", + " # t6 = prims.linear(input, t5, t1) # t6: \"cuda:0 f32[64, 64]\"\n", + " del t5, t1\n", + " [t7, t8] = nvFusion0(t6)\n", + " # t7 = prims.gt(t6, 0.0) # t7: \"cuda:0 b8[64, 64]\"\n", + " # t8 = prims.where(t7, t6, 0.0) # t8: \"cuda:0 f32[64, 64]\"\n", + " del t6\n", + " t10 = torch_wait_prim_impl(p9) # t10: \"cuda:0 f32[64, 64]\"\n", + " del p9\n", + " t11 = torch.nn.functional.linear(t8, t10, t3) # t11: \"cuda:0 f32[64, 64]\"\n", + " # t11 = ltorch.linear(t8, t10, t3) # t11: \"cuda:0 f32[64, 64]\"\n", + " # t11 = prims.linear(t8, t10, t3) # t11: \"cuda:0 f32[64, 64]\"\n", + " del t3\n", + " return {'output': t11, 'flat_args': [input, t_0_bias, t_2_bias, t_0_weight, t_2_weight], 'flat_output': (t11,)}, ((input, t10, t7, t8), ())\n", + "********************************************************\n", + "# Constructed by Delete Last Used (took 0 milliseconds)\n", + "import torch\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def backward_fn(saved_for_backward, cotangents):\n", + " # saved_for_backward: \"Collection\" \n", + " # cotangents: \"Collection\" \n", + " C0, \\\n", + " _, \\\n", + " = saved_for_backward\n", + " clear_collection(saved_for_backward)\n", + " del saved_for_backward\n", + " t0, \\\n", + " = cotangents\n", + " clear_collection(cotangents)\n", + " del cotangents\n", + " input, \\\n", + " t10, \\\n", + " t7, \\\n", + " t8, \\\n", + " = C0\n", + " clear_collection(C0)\n", + " del C0\n", + " t31 = torch.reshape(t0, (-1, 64)) # t31: \"cuda:0 f32[64, 64]\"\n", + " # t31 = ltorch.reshape(t0, (-1, 64)) # t31: \"cuda:0 f32[64, 64]\"\n", + " # t31 = prims.reshape(t0, (64, 64)) # t31: \"cuda:0 f32[64, 64]\"\n", + " t32 = torch.permute(t31, (1, 0)) # t32: \"cuda:0 f32[64, 64]\"\n", + " # t32 = ltorch.permute(t31, (1, 0)) # t32: \"cuda:0 f32[64, 64]\"\n", + " # t32 = prims.transpose(t31, (1, 0)) # t32: \"cuda:0 f32[64, 64]\"\n", + " t33 = torch.reshape(t8, (-1, 64)) # t33: \"cuda:0 f32[64, 64]\"\n", + " # t33 = ltorch.reshape(t8, (-1, 64)) # t33: \"cuda:0 f32[64, 64]\"\n", + " # t33 = prims.reshape(t8, (64, 64)) # t33: \"cuda:0 f32[64, 64]\"\n", + " del t8\n", + " t45 = torch.reshape(input, (-1, 64)) # t45: \"cuda:0 f32[64, 64]\"\n", + " # t45 = ltorch.reshape(input, (-1, 64)) # t45: \"cuda:0 f32[64, 64]\"\n", + " # t45 = prims.reshape(input, (64, 64)) # t45: \"cuda:0 f32[64, 64]\"\n", + " del input\n", + " [t51] = nvFusion0(t0)\n", + " # t35 = prims.sum(t0, (0,)) # t35: \"cuda:0 f32[64]\"\n", + " # t51 = prims.div(t35, 2.0) # t51: \"cuda:0 f32[64]\"\n", + " del t0\n", + " p52 = torch_reduce_scatter_prim_impl(t51, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p52: \"FUTURE cuda:0 f32[32]\"\n", + " del t51\n", + " t30 = torch.matmul(t31, t10) # t30: \"cuda:0 f32[64, 64]\"\n", + " # t30 = ltorch.matmul(t29, t10) # t30: \"cuda:0 f32[64, 64]\"\n", + " # t30 = prims.matmul(t29, t10) # t30: \"cuda:0 f32[64, 64]\"\n", + " del t31, t10\n", + " t34 = torch.matmul(t32, t33) # t34: \"cuda:0 f32[64, 64]\"\n", + " # t34 = ltorch.matmul(t32, t33) # t34: \"cuda:0 f32[64, 64]\"\n", + " # t34 = prims.matmul(t32, t33) # t34: \"cuda:0 f32[64, 64]\"\n", + " del t32, t33\n", + " [t36, t39, t54] = nvFusion1(t30, t34, t7)\n", + " # t39 = prims.where(t7, t30, 0.0) # t39: \"cuda:0 f32[64, 64]\"\n", + " # t47 = prims.sum(t39, (0,)) # t47: \"cuda:0 f32[64]\"\n", + " # t54 = prims.div(t47, 2.0) # t54: \"cuda:0 f32[64]\"\n", + " # t36 = prims.div(t34, 2.0) # t36: \"cuda:0 f32[64, 64]\"\n", + " del t30, t34, t7\n", + " p37 = torch_reduce_scatter_prim_impl(t36, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p37: \"FUTURE cuda:0 f32[32, 64]\"\n", + " del t36\n", + " p55 = torch_reduce_scatter_prim_impl(t54, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p55: \"FUTURE cuda:0 f32[32]\"\n", + " del t54\n", + " t43 = torch.reshape(t39, (-1, 64)) # t43: \"cuda:0 f32[64, 64]\"\n", + " # t43 = ltorch.reshape(t39, (-1, 64)) # t43: \"cuda:0 f32[64, 64]\"\n", + " # t43 = prims.reshape(t39, (64, 64)) # t43: \"cuda:0 f32[64, 64]\"\n", + " del t39\n", + " t44 = torch.permute(t43, (1, 0)) # t44: \"cuda:0 f32[64, 64]\"\n", + " # t44 = ltorch.permute(t43, (1, 0)) # t44: \"cuda:0 f32[64, 64]\"\n", + " # t44 = prims.transpose(t43, (1, 0)) # t44: \"cuda:0 f32[64, 64]\"\n", + " del t43\n", + " t46 = torch.matmul(t44, t45) # t46: \"cuda:0 f32[64, 64]\"\n", + " # t46 = ltorch.matmul(t44, t45) # t46: \"cuda:0 f32[64, 64]\"\n", + " # t46 = prims.matmul(t44, t45) # t46: \"cuda:0 f32[64, 64]\"\n", + " del t44, t45\n", + " [t48] = nvFusion2(t46)\n", + " # t48 = prims.div(t46, 2.0) # t48: \"cuda:0 f32[64, 64]\"\n", + " del t46\n", + " p49 = torch_reduce_scatter_prim_impl(t48, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p49: \"FUTURE cuda:0 f32[32, 64]\"\n", + " del t48\n", + " t53 = torch_wait_prim_impl(p52) # t53: \"cuda:0 f32[32]\"\n", + " del p52\n", + " t38 = torch_wait_prim_impl(p37) # t38: \"cuda:0 f32[32, 64]\"\n", + " del p37\n", + " t56 = torch_wait_prim_impl(p55) # t56: \"cuda:0 f32[32]\"\n", + " del p55\n", + " t50 = torch_wait_prim_impl(p49) # t50: \"cuda:0 f32[32, 64]\"\n", + " del p49\n", + " return (None, t56, t53, t50, t38)\n" + ] + } + ], + "source": [ + "!torchrun --nproc_per_node=2 thunder_fsdp_simple_example.py" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Conclusion\n", + "\n", + "We have created our implementation of FSDP to shard our model across multiple GPUs. In the process, we also learned that:\n", + "\n", + "1. `thunder` provides us with primitives for synchronization across mutiple GPUs.\n", + "2. `thunder` also takes care of implementing the backward support for the synchronization primitives, so we don't have to explicitly do anything to get the backward working.\n", + "3. We can just easily apply `thunder.distributed.fsdp` to our model and it will take care of sharding the parameters and also adding synchronizations to our model. Also, we can easily check the modifications by inspecting the traces." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/dev_tutorials/patterns.ipynb b/notebooks/dev_tutorials/patterns.ipynb deleted file mode 100644 index b22a8598fd..0000000000 --- a/notebooks/dev_tutorials/patterns.ipynb +++ /dev/null @@ -1,441 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Thunder pattern matching for transformations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# This developer tutorial discusses patterns -- sequences of operations that can be matched and replaced with traceable functions. \n", - "# It's a work-in-progress, and it currently only discusses how patterns can be constructed and how they're matched,\n", - "# along with some related utilities." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# Imports the modules, classes, and functions we'll need for this tutorial\n", - "import torch\n", - "\n", - "import thunder\n", - "from thunder.core.patterns import Pattern, bind_names, numbered_ancestors\n", - "from thunder.core.proxies import TensorProxy\n", - "from thunder.core.symbol import BoundSymbol" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "# To match a pattern, start by creating a Pattern object\n", - "p = Pattern()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[(2, t0 = ltorch.add(a, b, alpha=None) # t0: \"cpu f32[2, 2]\"\n", - " # t0 = prims.add(a, b) # t0: \"cpu f32[2, 2]\")]]\n" - ] - } - ], - "source": [ - "# Then define one or more \"matchers\" that determine if a BoundSymbol is a \"match\", and add them to the \n", - "# pattern using its match() method\n", - "\n", - "# The matcher signature not only accepts a BoundSymbol to review, but also a list of BoundSymbols that were\n", - "# already matched by the pattern, and a match_ctx dictionary that contains whatever state you like from previous matches\n", - "# The matcher returns True if the BoundSymbol should be matched, and False otherwise. When returning True it should return\n", - "# a dict that will be used to update the match_ctx for future matches. This will be clearer in a moment with an example.\n", - "# The following matcher is very permissive, and it matches any add operation.\n", - "def add_matcher(bsym: BoundSymbol, *, previously_matched: list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:\n", - " if bsym.sym.name == 'add':\n", - " return True, {}\n", - " \n", - " return False, None\n", - "\n", - "a = torch.randn((2, 2))\n", - "b = torch.randn((2, 2))\n", - "\n", - "# An example program that performs an addition and a subtraction\n", - "def foo(a, b):\n", - " c = a + b\n", - " d = a - b\n", - " return c, d\n", - "trc = thunder.trace()(foo, a, b)\n", - "\n", - "# The matcher is told to match any addition\n", - "p.match(add_matcher)\n", - "\n", - "# Calling the Pattern object on a trace returns a list of matches. \n", - "# Each match is a list of (int, BoundSymbol) tuples, where int is the \n", - "# position of the BoundSymbol in the trace.\n", - "matches = p(trc)\n", - "\n", - "# In this case there is just one match -- the first addition\n", - "print(matches)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[(2, t0 = ltorch.add(a, b, alpha=None) # t0: \"cpu f32[2, 2]\"\n", - " # t0 = prims.add(a, b) # t0: \"cpu f32[2, 2]\")], [(3, t1 = ltorch.add(a, b, alpha=None) # t1: \"cpu f32[2, 2]\"\n", - " # t1 = prims.add(a, b) # t1: \"cpu f32[2, 2]\")]]\n" - ] - } - ], - "source": [ - "def foo(a, b):\n", - " c = a + b\n", - " d = a + b\n", - " return c, d\n", - "trc = thunder.trace()(foo, a, b)\n", - "\n", - "# When the program is changed to include two additions, both additions\n", - "# are matched and two matches are created.\n", - "matches = p(trc)\n", - "print(matches)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[(2, t0 = ltorch.add(a, b, alpha=None) # t0: \"cpu f32[2, 2]\"\n", - " # t0 = prims.add(a, b) # t0: \"cpu f32[2, 2]\"), (3, t1 = ltorch.add(a, b, alpha=None) # t1: \"cpu f32[2, 2]\"\n", - " # t1 = prims.add(a, b) # t1: \"cpu f32[2, 2]\")]]\n" - ] - } - ], - "source": [ - "# In addition to matching a single operation, a pattern can match any number of sequential operations --\n", - "# that is, operations that are immediately adjacent to each other. We do this by providing\n", - "# max_times and (optionally) min_times arguments to match()\n", - "# Negative max_times values are interpreted as matching the pattern any number of times\n", - "# Matching multiple operations occurs greedily and before any additional matching can occur\n", - "\n", - "p = Pattern()\n", - "p.match(add_matcher, min_times=1, max_times=-1)\n", - "\n", - "# The pattern now matches once, and the single match contains both additions\n", - "matches = p(trc)\n", - "print(matches)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[(2, t0 = ltorch.add(a, b, alpha=None) # t0: \"cpu f32[2, 2]\"\n", - " # t0 = prims.add(a, b) # t0: \"cpu f32[2, 2]\"), (3, t1 = ltorch.sub(a, b, alpha=None) # t1: \"cpu f32[2, 2]\"\n", - " # t1 = prims.sub(a, b) # t1: \"cpu f32[2, 2]\")]]\n" - ] - } - ], - "source": [ - "# Multiple operations can also be matched by calling match() multiple times. Each \n", - "# match() attempts to evaluate itself in the order it's called.\n", - "\n", - "def foo(a, b):\n", - " c = a + b\n", - " d = a - b\n", - " return c, d\n", - "trc = thunder.trace()(foo, a, b)\n", - "\n", - "# Let's match an addition followed by a subtraction on the same inputs. This will also show how to update the match_ctx dict\n", - "# and let us use the bind_names() utility.\n", - "def add_matcher(bsym: BoundSymbol, *, previously_matched: list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:\n", - " if bsym.sym.name == 'add':\n", - " # bind_names() produces an object with properties corresponding to the function's (Symbol's) parameters, when\n", - " # accessed they return their corresponding arguments\n", - " bn = bind_names(bsym)\n", - " # Stores the inputs in the context\n", - " return True, {'a': bn.a, 'b': bn.b}\n", - " \n", - " return False, None\n", - "\n", - "def sub_matcher(bsym: BoundSymbol, *, previously_matched: list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:\n", - " if bsym.sym.name == 'sub':\n", - " bn = bind_names(bsym)\n", - "\n", - " # Acquires the previously stored values from the match_ctx\n", - " a = match_ctx['a']\n", - " b = match_ctx['b']\n", - "\n", - " # Matches the sub only if the arguments are the same as the addition's, and in the same order\n", - " if a is bn.a and b is bn.b:\n", - " return True, {}\n", - " \n", - " return False, None\n", - "\n", - "p = Pattern()\n", - "p.match(add_matcher)\n", - "p.match(sub_matcher)\n", - "\n", - "# Matches the addition and the subtraction\n", - "matches = p(trc)\n", - "print(matches)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Another version of the above example that uses the previously_matched argument to decide whether to match\n", - "# the subtraction\n", - "\n", - "def foo(a, b):\n", - " c = a + b\n", - " d = a - b\n", - " return c, d\n", - "trc = thunder.trace()(foo, a, b)\n", - "\n", - "def add_matcher(bsym: BoundSymbol, *, previously_matched: list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:\n", - " if bsym.sym.name == 'add':\n", - " # Doesn't update the context -- the context is just scratch space for you\n", - " return True, {}\n", - " \n", - " return False, None\n", - "\n", - "def sub_matcher(bsym: BoundSymbol, *, previously_matched: list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:\n", - " if bsym.sym.name == 'sub':\n", - " my_bn = bind_names(bsym)\n", - "\n", - " add_bsym = previously_matched\n", - " add_bn = bind_names(add_bsym)\n", - "\n", - " # Matches the sub only if the arguments are the same as the addition's, and in the same order\n", - " if add_bn.a is my_bn.a and add_bn.b is my_bn.b:\n", - " return True, {}\n", - " \n", - " return False, None\n", - "\n", - "p = Pattern()\n", - "p.match(add_matcher)\n", - "p.match(sub_matcher)\n", - "\n", - "# Matches the addition and the subtraction\n", - "matches = p(trc)\n", - "print(matches)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[(2, t0 = ltorch.add(a, b, alpha=None) # t0: \"cpu f32[2, 2]\"\n", - " # t0 = prims.add(a, b) # t0: \"cpu f32[2, 2]\"), (4, t2 = ltorch.sub(a, b, alpha=None) # t2: \"cpu f32[2, 2]\"\n", - " # t2 = prims.sub(a, b) # t2: \"cpu f32[2, 2]\")]]\n" - ] - } - ], - "source": [ - "# Operations in a pattern don't have to be next to each other, but they do have to within a \"window\" of \n", - "# the previous operation. Currently the window is 5 operations. Each operation also has to be \n", - "# \"reorderable\" to be \"next to\" operations that were already matched. This isn't always\n", - "# possible. If an operation consumes an input that is not directly from a previously matched symbol, but\n", - "# is derived from the output of a previously matched symbol, then it cannot be reordered adjacent to the \n", - "# other operations in the pattern.\n", - "# Let's see how this works with two examples.\n", - "\n", - "# An operation between the first addition and second subtraction doesn't stop the previous pattern\n", - "# from matching as expected, because the operation producing d can be reordered to be \n", - "# immediately after the operation producing c\n", - "def foo(a, b):\n", - " c = a + b\n", - " x = a + 2\n", - " d = a - b\n", - " return c, d, x\n", - "trc = thunder.trace()(foo, a, b)\n", - "\n", - "# The match is the same as when x isn't computed\n", - "matches = p(trc)\n", - "print(matches)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[]\n" - ] - } - ], - "source": [ - "# Too many intervening operations pushes the subtraction out of pattern matching \"window\" and prevents\n", - "# the match\n", - "# In the future we may expose an option to set the window larger -- share your thoughts by filing an issue!\n", - "def foo(a, b):\n", - " c = a + b\n", - " x = a + 2\n", - " x = x + 2\n", - " x = x + 2\n", - " x = x + 2\n", - " x = x + 2\n", - " x = x + 2\n", - " x = x + 2\n", - " x = x + 2\n", - " d = a - b\n", - " return c, d, x\n", - "trc = thunder.trace()(foo, a, b)\n", - "\n", - "# No matches because the computation of c and the computation of d are separated by too many operations\n", - "matches = p(trc)\n", - "print(matches)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[]\n" - ] - } - ], - "source": [ - "# The computation of e depends on the computation of d and the computation of c\n", - "def foo(a, b):\n", - " c = a + b\n", - " d = c - 5\n", - " e = c + d\n", - " return e\n", - "trc = thunder.trace()(foo, a, b)\n", - "\n", - "def add_matcher(bsym: BoundSymbol, *, previously_matched: list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:\n", - " if bsym.sym.name == 'add':\n", - " return True, {}\n", - " \n", - " return False, None\n", - "\n", - "p = Pattern()\n", - "p.match(add_matcher)\n", - "p.match(add_matcher)\n", - "\n", - "# Attempting to match two additions fails, because the computation of e cannot be reordered next to the computation of c\n", - "matches = p(trc)\n", - "print(matches)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[(2, t0 = ltorch.add(a, b, alpha=None) # t0: \"cpu f32[2, 2]\"\n", - " # t0 = prims.add(a, b) # t0: \"cpu f32[2, 2]\"), (3, t1 = ltorch.sub(t0, 5, alpha=None) # t1: \"cpu f32[2, 2]\"\n", - " # _ = prims.convert_element_type(5, float)\n", - " # t1 = prims.sub(t0, 5.0) # t1: \"cpu f32[2, 2]\"), (4, t2 = ltorch.add(t0, t1, alpha=None) # t2: \"cpu f32[2, 2]\"\n", - " # t2 = prims.add(t0, t1) # t2: \"cpu f32[2, 2]\")]]\n" - ] - } - ], - "source": [ - "# Including the subtraction in the pattern allows it to be matched\n", - "def sub_matcher(bsym: BoundSymbol, *, previously_matched: list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:\n", - " if bsym.sym.name == 'sub':\n", - " return True, {}\n", - " \n", - " return False, None\n", - "\n", - "p = Pattern()\n", - "p.match(add_matcher)\n", - "p.match(sub_matcher)\n", - "p.match(add_matcher)\n", - "\n", - "matches = p(trc)\n", - "print(matches)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.7" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/fsdp_tutorial.ipynb b/notebooks/fsdp_tutorial.ipynb deleted file mode 100644 index 71ed1b1005..0000000000 --- a/notebooks/fsdp_tutorial.ipynb +++ /dev/null @@ -1,1489 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## FSDP Tutorial\n", - "\n", - "In this tutorial, we will walk through the implementation of Fully Sharded Data Parallel (FSDP) with Zero2 sharding strategy in `thunder`." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Introduction\n", - "\n", - "In recent times, the LLM models have grown so large that all the model parameters don't fit on a single GPU. To circumvent this problem, there are various strategies like Tensor Parallel, Pipeline Parallel, Fully Sharded Data Parallel, etc to train these large models. In this tutorial, we discuss and implement Zero2 strategy for Fully Sharded Data Parallel (FSDP).\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### What is Zero2 strategy for FSDP?\n", - "\n", - "In this strategy, we shard the model parameters across all the availabe GPUs. That is each GPU holds onto only a chunk of the parameter. During the forward pass, all GPUs call `all_gather` communication primitive to gather the parameters from other GPUs. Unlike Zero3 strategy which frees the parameter after forward pass, we save these unsharded parameters for backward pass. This is to save the overhead of extra communication. In the backward pass, we utilize the saved parameters and compute the gradients. Once the gradients are computed, we use `reduce_scatter` communication primitive to reduce (average) the gradients across all GPUs and scatter those gradients so that a given GPU holds only a chunk of gradient.\n", - "\n", - "For more information on FSDP, we recommend reading\n", - "\n", - "1. PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel - [Link](https://arxiv.org/abs/2304.11277)\n", - "2. ZeRO: Memory Optimizations Toward Training Trillion Parameter Models - [Link](https://arxiv.org/abs/1910.02054)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Example Model\n", - "\n", - "For this example we will have a simple model `Linear(Tanh(Linear(x)))` which will be sharded over 2 GPUs\n", - "\n", - "**NOTE**: We are generating the abstract trace so we don't actually need a system with 2 GPUs for this. It is only required when we execute this trace." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torch.distributed\n", - "import thunder\n", - "import thunder.distributed\n", - "from IPython.display import Code" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "device='cuda'\n", - "dim = 64\n", - "def create_model():\n", - " layers = [torch.nn.Linear(dim, dim, bias=False),\n", - " torch.nn.Tanh(),\n", - " torch.nn.Linear(dim, dim, bias=False)]\n", - " return torch.nn.Sequential(*layers).to(device)\n", - "\n", - "# Model\n", - "model = create_model()\n", - "# Input\n", - "x = torch.randn(dim, dim, device=device)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "def wrap_as_highlighted_code(trace):\n", - " return Code(str(trace), language=\"python\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Step 1 : Configuration\n", - "\n", - "For our implementation of FSDP, we will generate the trace where we are sharding our model over 2 GPU" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "# FSDP Config \n", - "# Usually these values are set in the environment by `torchrun` but for this example\n", - "# we will set them ourselves\n", - "world_size = 2 # We have two processes.\n", - "global_rank = 0 # Current process is the very first process." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Step 2: Function to shard parameters\n", - "\n", - "Next step is to write a function which will actually shard the parameters over 0-dim." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "# NOTE: We shard over 0th dimension of the param.\n", - "def shard_param(param: torch.Tensor, rank: int, world_size: int, name: str) -> None:\n", - " # We will keep it simple and error if param's 0th dim is not divisible by ``world_size``.\n", - " # Alternative is that we can pad our parameters so that they are divisible by `world_size`.\n", - " assert param.shape[0] % world_size == 0,(\n", - " f\"Current sharding requires the first dimension of the parameter {name!r} ({param.shape[0]})\"\n", - " f\" to be divisible by the world size ({world_size})\"\n", - " )\n", - " chunk_size = param.shape[0] // world_size\n", - "\n", - " # rank helps us determine which chunk of the parameter we will hold.\n", - " shard = param.data.narrow(0, chunk_size * rank, chunk_size).clone()\n", - " param.data = shard\n", - "\n", - "# Shard each parameter of the model\n", - "for param_name, param in model.named_parameters():\n", - " shard_param(param, global_rank, world_size, param_name)\n", - " # Mark the param to denote that it is sharded.\n", - " # This is required by the synchronization primitive we will use below.\n", - " param.ddp_type = thunder.core.proxies.DDPType.FULLY_SHARDED" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Sequential(\n", - " (0): Linear(in_features=64, out_features=64, bias=False)\n", - " (1): Tanh()\n", - " (2): Linear(in_features=64, out_features=64, bias=False)\n", - ")" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Verify our model looks as expected\n", - "model" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "# Let us verify that we have actually sharded the parameters.\n", - "# Checking if the weight of 1st Linear layer is sharded over 0th dim.\n", - "assert model[0].weight.shape == (dim / world_size, dim)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Step 3: Add an operation to synchronize the parameters before calling the model.forward." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We have to create a process group. This is needed because the synchronization primitive `synchronize` that we will use to gather and scatter our weights in forward and backward requires a process group." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "# Create a process group\n", - "options = torch.distributed.distributed_c10d.ProcessGroup.Options(backend=\"nccl\")\n", - "process_group = torch.distributed.distributed_c10d.ProcessGroup(torch.distributed.distributed_c10d.Store(),\n", - " global_rank, world_size, options)\n", - "torch.distributed.distributed_c10d.GroupMember.WORLD = process_group" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "# `preprocess` gives us the functional version of the model which\n", - "# takes as inputs all the parameters and the expected arguments.\n", - "# NOTE: `thunder.common.preprocess` is not meant for general use\n", - "# and used only for brevity of code. It will be updated\n", - "# to a newer mechanism which is meant to be public facing. \n", - "functional_forward = thunder.common.preprocess(model, is_module=True)\n", - "\n", - "# This function creates a model with synchronization\n", - "# before calling the forward pass.\n", - "def model_with_syncs(*params, x):\n", - " # We call `prims.synchronize` on all the parameters.\n", - " # This is essentially calling `all_gather` so that we have the complete\n", - " # parameter before we actually to the forward computation.\n", - " unsharded_params = []\n", - " for param in params:\n", - " unsharded_params.append(thunder.distributed.prims.synchronize(param, process_group))\n", - "\n", - " return functional_forward(*unsharded_params, x)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let us now see what the trace of our model looks like with all the synchronization.\n", - "\n", - "Two main observations regarding the below trace \n", - "1. We can observe the `prims.synchronize` that we inserted using `model_with_syncs`.\n", - "2. Output of the `prims.synchronize` have the shape of unsharded (original) parameter.\n", - "\n", - "With this, we have implemented the FSDP for the forward pass of our model." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
# Constructed by Dead Code Elimination (took 0 milliseconds)\n",
-       "import thunder\n",
-       "import thunder.distributed.prims\n",
-       "import thunder.torch as ltorch\n",
-       "import torch\n",
-       "from thunder.executors.torchex import no_autocast\n",
-       "\n",
-       "@torch.no_grad()\n",
-       "@no_autocast()\n",
-       "def model_with_syncs(*params, x):\n",
-       "  # params \n",
-       "  # x \n",
-       "  t0, \\\n",
-       "  t1, \\\n",
-       "  = params\n",
-       "  t2 = thunder.distributed.prims.synchronize(t0, _torch_distributed_distributed_c10d_ProcessGroup_0)  # t2\n",
-       "  t3 = thunder.distributed.prims.synchronize(t1, _torch_distributed_distributed_c10d_ProcessGroup_0)  # t3\n",
-       "  t4 = ltorch.linear(x, t2, None)  # t4\n",
-       "    # t4 = prims.linear(x, t2, None)  # t4\n",
-       "  t5 = ltorch.tanh(t4)  # t5\n",
-       "    # t5 = prims.tanh(t4)  # t5\n",
-       "  t6 = ltorch.linear(t5, t3, None)  # t6\n",
-       "    # t6 = prims.linear(t5, t3, None)  # t6\n",
-       "  return t6\n",
-       "
\n" - ], - "text/latex": [ - "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", - "\\PY{c+c1}{\\PYZsh{} Constructed by Dead Code Elimination (took 0 milliseconds)}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{distributed}\\PY{n+nn}{.}\\PY{n+nn}{prims}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{torch} \\PY{k}{as} \\PY{n+nn}{ltorch}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", - "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n", - "\n", - "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{k}{def} \\PY{n+nf}{model\\PYZus{}with\\PYZus{}syncs}\\PY{p}{(}\\PY{o}{*}\\PY{n}{params}\\PY{p}{,} \\PY{n}{x}\\PY{p}{)}\\PY{p}{:}\n", - " \\PY{c+c1}{\\PYZsh{} params }\n", - " \\PY{c+c1}{\\PYZsh{} x }\n", - " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{o}{=} \\PY{n}{params}\n", - " \\PY{n}{t2} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{synchronize}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t2}\n", - " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{synchronize}\\PY{p}{(}\\PY{n}{t1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3}\n", - " \\PY{n}{t4} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{x}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t4}\n", - " \\PY{c+c1}{\\PYZsh{} t4 = prims.linear(x, t2, None) \\PYZsh{} t4}\n", - " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t4}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5}\n", - " \\PY{c+c1}{\\PYZsh{} t5 = prims.tanh(t4) \\PYZsh{} t5}\n", - " \\PY{n}{t6} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t5}\\PY{p}{,} \\PY{n}{t3}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t6}\n", - " \\PY{c+c1}{\\PYZsh{} t6 = prims.linear(t5, t3, None) \\PYZsh{} t6}\n", - " \\PY{k}{return} \\PY{n}{t6}\n", - "\\end{Verbatim}\n" - ], - "text/plain": [ - "# Constructed by Dead Code Elimination (took 0 milliseconds)\n", - "import thunder\n", - "import thunder.distributed.prims\n", - "import thunder.torch as ltorch\n", - "import torch\n", - "from thunder.executors.torchex import no_autocast\n", - "\n", - "@torch.no_grad()\n", - "@no_autocast()\n", - "def model_with_syncs(*params, x):\n", - " # params \n", - " # x \n", - " t0, \\\n", - " t1, \\\n", - " = params\n", - " t2 = thunder.distributed.prims.synchronize(t0, _torch_distributed_distributed_c10d_ProcessGroup_0) # t2\n", - " t3 = thunder.distributed.prims.synchronize(t1, _torch_distributed_distributed_c10d_ProcessGroup_0) # t3\n", - " t4 = ltorch.linear(x, t2, None) # t4\n", - " # t4 = prims.linear(x, t2, None) # t4\n", - " t5 = ltorch.tanh(t4) # t5\n", - " # t5 = prims.tanh(t4) # t5\n", - " t6 = ltorch.linear(t5, t3, None) # t6\n", - " # t6 = prims.linear(t5, t3, None) # t6\n", - " return t6" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "trace = thunder.trace()(model_with_syncs, *model.parameters(), x=x)\n", - "\n", - "wrap_as_highlighted_code(trace)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "For backward, we don't have to do anything because `thunder` already knows how to compute the backward of `prims.synchronize`. We can verify that by using the `value_and_grad` transform to generate the complete forward and backward trace together.\n", - "\n", - "Observations for the trace below:\n", - "1. `prims.synchronize` from previous trace is now decomposed into `prims.all_gather` and `prims.wait`. So, we can clearly see that we make a communication call to gather the parameter (which is asynchronous) and wait till we have the complete parameter.\n", - "2. At the end of the trace (after the forward and the backward computation), we see calls to `prims.reduce_scatter` and `prims.wait`. This takes care of reducing the gradients across all the GPUs and sharding them. One thing to note, for averaging gradients with low dynamic range dtype like `float16`, if we naively sum the gradients across GPUs before dividing by `world_size`, it can lead to overflows. So we scale the gradient tensor with `world_size`, before calling `reduce_scatter` with `sum` reduction to effectively average the gradients without overflow." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
# Constructed by Dead Code Elimination (took 0 milliseconds)\n",
-       "import thunder\n",
-       "import thunder.core.devices as devices\n",
-       "import thunder.core.dtypes as dtypes\n",
-       "import thunder.core.prims as prims\n",
-       "import thunder.distributed.prims\n",
-       "import thunder.torch as ltorch\n",
-       "import torch\n",
-       "from thunder.executors.torchex import no_autocast\n",
-       "\n",
-       "@torch.no_grad()\n",
-       "@no_autocast()\n",
-       "def _value_and_grad(*args, **kwargs):\n",
-       "  # args \n",
-       "  # kwargs \n",
-       "  t0, \\\n",
-       "  t1, \\\n",
-       "  = args\n",
-       "  t2 = kwargs['x']\n",
-       "  t3 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t3\n",
-       "  p4 = thunder.distributed.prims.all_gather(t0, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p4\n",
-       "  t5 = thunder.distributed.prims.wait(p4)  # t5\n",
-       "  p6 = thunder.distributed.prims.all_gather(t1, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p6\n",
-       "  t7 = thunder.distributed.prims.wait(p6)  # t7\n",
-       "  t8 = prims.linear(t2, t5, None)  # t8\n",
-       "  t9 = prims.tanh(t8)  # t9\n",
-       "  t10 = prims.linear(t9, t7, None)  # t10\n",
-       "  t11 = ltorch.reshape(t3, -1, 64)  # t11\n",
-       "    # t11 = prims.reshape(t3, (64, 64))  # t11\n",
-       "  t12 = ltorch.matmul(t11, t7)  # t12\n",
-       "    # t12 = prims.matmul(t11, t7)  # t12\n",
-       "  t13 = ltorch.reshape(t3, -1, 64)  # t13\n",
-       "    # t13 = prims.reshape(t3, (64, 64))  # t13\n",
-       "  t14 = prims.transpose(t13, (1, 0))  # t14\n",
-       "  t15 = ltorch.reshape(t9, -1, 64)  # t15\n",
-       "    # t15 = prims.reshape(t9, (64, 64))  # t15\n",
-       "  t16 = ltorch.matmul(t14, t15)  # t16\n",
-       "    # t16 = prims.matmul(t14, t15)  # t16\n",
-       "  t17 = ltorch.mul(t9, t9)  # t17\n",
-       "    # t17 = prims.mul(t9, t9)  # t17\n",
-       "  t18 = ltorch.sub(1.0, t17, alpha=None)  # t18\n",
-       "    # t18 = prims.sub(1.0, t17)  # t18\n",
-       "  t19 = ltorch.mul(t12, t18)  # t19\n",
-       "    # t19 = prims.mul(t12, t18)  # t19\n",
-       "  t20 = ltorch.reshape(t19, -1, 64)  # t20\n",
-       "    # t20 = prims.reshape(t19, (64, 64))  # t20\n",
-       "  t21 = ltorch.matmul(t20, t5)  # t21\n",
-       "    # t21 = prims.matmul(t20, t5)  # t21\n",
-       "  t22 = ltorch.reshape(t19, -1, 64)  # t22\n",
-       "    # t22 = prims.reshape(t19, (64, 64))  # t22\n",
-       "  t23 = prims.transpose(t22, (1, 0))  # t23\n",
-       "  t24 = ltorch.reshape(t2, -1, 64)  # t24\n",
-       "    # t24 = prims.reshape(t2, (64, 64))  # t24\n",
-       "  t25 = ltorch.matmul(t23, t24)  # t25\n",
-       "    # t25 = prims.matmul(t23, t24)  # t25\n",
-       "  t26 = ltorch.true_divide(t16, 2)  # t26\n",
-       "    # _ = prims.convert_element_type(2, float)\n",
-       "    # t26 = prims.div(t16, 2.0)  # t26\n",
-       "  p27 = thunder.distributed.prims.reduce_scatter(t26, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p27\n",
-       "  t28 = thunder.distributed.prims.wait(p27)  # t28\n",
-       "  t29 = ltorch.true_divide(t25, 2)  # t29\n",
-       "    # _ = prims.convert_element_type(2, float)\n",
-       "    # t29 = prims.div(t25, 2.0)  # t29\n",
-       "  p30 = thunder.distributed.prims.reduce_scatter(t29, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p30\n",
-       "  t31 = thunder.distributed.prims.wait(p30)  # t31\n",
-       "  return (t10, (t31, t28, {'x': t21}))\n",
-       "
\n" - ], - "text/latex": [ - "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", - "\\PY{c+c1}{\\PYZsh{} Constructed by Dead Code Elimination (took 0 milliseconds)}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{devices} \\PY{k}{as} \\PY{n+nn}{devices}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{dtypes} \\PY{k}{as} \\PY{n+nn}{dtypes}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{prims} \\PY{k}{as} \\PY{n+nn}{prims}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{distributed}\\PY{n+nn}{.}\\PY{n+nn}{prims}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{torch} \\PY{k}{as} \\PY{n+nn}{ltorch}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", - "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n", - "\n", - "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{k}{def} \\PY{n+nf}{\\PYZus{}value\\PYZus{}and\\PYZus{}grad}\\PY{p}{(}\\PY{o}{*}\\PY{n}{args}\\PY{p}{,} \\PY{o}{*}\\PY{o}{*}\\PY{n}{kwargs}\\PY{p}{)}\\PY{p}{:}\n", - " \\PY{c+c1}{\\PYZsh{} args }\n", - " \\PY{c+c1}{\\PYZsh{} kwargs }\n", - " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{o}{=} \\PY{n}{args}\n", - " \\PY{n}{t2} \\PY{o}{=} \\PY{n}{kwargs}\\PY{p}{[}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{x}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{]}\n", - " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3}\n", - " \\PY{n}{p4} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{all\\PYZus{}gather}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p4}\n", - " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p4}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5}\n", - " \\PY{n}{p6} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{all\\PYZus{}gather}\\PY{p}{(}\\PY{n}{t1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p6}\n", - " \\PY{n}{t7} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p6}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t7}\n", - " \\PY{n}{t8} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t2}\\PY{p}{,} \\PY{n}{t5}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t8}\n", - " \\PY{n}{t9} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t8}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t9}\n", - " \\PY{n}{t10} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t9}\\PY{p}{,} \\PY{n}{t7}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t10}\n", - " \\PY{n}{t11} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t11}\n", - " \\PY{c+c1}{\\PYZsh{} t11 = prims.reshape(t3, (64, 64)) \\PYZsh{} t11}\n", - " \\PY{n}{t12} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t11}\\PY{p}{,} \\PY{n}{t7}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t12}\n", - " \\PY{c+c1}{\\PYZsh{} t12 = prims.matmul(t11, t7) \\PYZsh{} t12}\n", - " \\PY{n}{t13} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t13}\n", - " \\PY{c+c1}{\\PYZsh{} t13 = prims.reshape(t3, (64, 64)) \\PYZsh{} t13}\n", - " \\PY{n}{t14} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{transpose}\\PY{p}{(}\\PY{n}{t13}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t14}\n", - " \\PY{n}{t15} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t9}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t15}\n", - " \\PY{c+c1}{\\PYZsh{} t15 = prims.reshape(t9, (64, 64)) \\PYZsh{} t15}\n", - " \\PY{n}{t16} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t14}\\PY{p}{,} \\PY{n}{t15}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t16}\n", - " \\PY{c+c1}{\\PYZsh{} t16 = prims.matmul(t14, t15) \\PYZsh{} t16}\n", - " \\PY{n}{t17} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t9}\\PY{p}{,} \\PY{n}{t9}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t17}\n", - " \\PY{c+c1}{\\PYZsh{} t17 = prims.mul(t9, t9) \\PYZsh{} t17}\n", - " \\PY{n}{t18} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{sub}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{,} \\PY{n}{t17}\\PY{p}{,} \\PY{n}{alpha}\\PY{o}{=}\\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t18}\n", - " \\PY{c+c1}{\\PYZsh{} t18 = prims.sub(1.0, t17) \\PYZsh{} t18}\n", - " \\PY{n}{t19} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t12}\\PY{p}{,} \\PY{n}{t18}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t19}\n", - " \\PY{c+c1}{\\PYZsh{} t19 = prims.mul(t12, t18) \\PYZsh{} t19}\n", - " \\PY{n}{t20} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t19}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t20}\n", - " \\PY{c+c1}{\\PYZsh{} t20 = prims.reshape(t19, (64, 64)) \\PYZsh{} t20}\n", - " \\PY{n}{t21} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t20}\\PY{p}{,} \\PY{n}{t5}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t21}\n", - " \\PY{c+c1}{\\PYZsh{} t21 = prims.matmul(t20, t5) \\PYZsh{} t21}\n", - " \\PY{n}{t22} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t19}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t22}\n", - " \\PY{c+c1}{\\PYZsh{} t22 = prims.reshape(t19, (64, 64)) \\PYZsh{} t22}\n", - " \\PY{n}{t23} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{transpose}\\PY{p}{(}\\PY{n}{t22}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t23}\n", - " \\PY{n}{t24} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t2}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t24}\n", - " \\PY{c+c1}{\\PYZsh{} t24 = prims.reshape(t2, (64, 64)) \\PYZsh{} t24}\n", - " \\PY{n}{t25} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t23}\\PY{p}{,} \\PY{n}{t24}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t25}\n", - " \\PY{c+c1}{\\PYZsh{} t25 = prims.matmul(t23, t24) \\PYZsh{} t25}\n", - " \\PY{n}{t26} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t26}\n", - " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n", - " \\PY{c+c1}{\\PYZsh{} t26 = prims.div(t16, 2.0) \\PYZsh{} t26}\n", - " \\PY{n}{p27} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{reduce\\PYZus{}scatter}\\PY{p}{(}\\PY{n}{t26}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p27}\n", - " \\PY{n}{t28} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p27}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t28}\n", - " \\PY{n}{t29} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t25}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t29}\n", - " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n", - " \\PY{c+c1}{\\PYZsh{} t29 = prims.div(t25, 2.0) \\PYZsh{} t29}\n", - " \\PY{n}{p30} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{reduce\\PYZus{}scatter}\\PY{p}{(}\\PY{n}{t29}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p30}\n", - " \\PY{n}{t31} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p30}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t31}\n", - " \\PY{k}{return} \\PY{p}{(}\\PY{n}{t10}\\PY{p}{,} \\PY{p}{(}\\PY{n}{t31}\\PY{p}{,} \\PY{n}{t28}\\PY{p}{,} \\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{x}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t21}\\PY{p}{\\PYZcb{}}\\PY{p}{)}\\PY{p}{)}\n", - "\\end{Verbatim}\n" - ], - "text/plain": [ - "# Constructed by Dead Code Elimination (took 0 milliseconds)\n", - "import thunder\n", - "import thunder.core.devices as devices\n", - "import thunder.core.dtypes as dtypes\n", - "import thunder.core.prims as prims\n", - "import thunder.distributed.prims\n", - "import thunder.torch as ltorch\n", - "import torch\n", - "from thunder.executors.torchex import no_autocast\n", - "\n", - "@torch.no_grad()\n", - "@no_autocast()\n", - "def _value_and_grad(*args, **kwargs):\n", - " # args \n", - " # kwargs \n", - " t0, \\\n", - " t1, \\\n", - " = args\n", - " t2 = kwargs['x']\n", - " t3 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t3\n", - " p4 = thunder.distributed.prims.all_gather(t0, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p4\n", - " t5 = thunder.distributed.prims.wait(p4) # t5\n", - " p6 = thunder.distributed.prims.all_gather(t1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p6\n", - " t7 = thunder.distributed.prims.wait(p6) # t7\n", - " t8 = prims.linear(t2, t5, None) # t8\n", - " t9 = prims.tanh(t8) # t9\n", - " t10 = prims.linear(t9, t7, None) # t10\n", - " t11 = ltorch.reshape(t3, -1, 64) # t11\n", - " # t11 = prims.reshape(t3, (64, 64)) # t11\n", - " t12 = ltorch.matmul(t11, t7) # t12\n", - " # t12 = prims.matmul(t11, t7) # t12\n", - " t13 = ltorch.reshape(t3, -1, 64) # t13\n", - " # t13 = prims.reshape(t3, (64, 64)) # t13\n", - " t14 = prims.transpose(t13, (1, 0)) # t14\n", - " t15 = ltorch.reshape(t9, -1, 64) # t15\n", - " # t15 = prims.reshape(t9, (64, 64)) # t15\n", - " t16 = ltorch.matmul(t14, t15) # t16\n", - " # t16 = prims.matmul(t14, t15) # t16\n", - " t17 = ltorch.mul(t9, t9) # t17\n", - " # t17 = prims.mul(t9, t9) # t17\n", - " t18 = ltorch.sub(1.0, t17, alpha=None) # t18\n", - " # t18 = prims.sub(1.0, t17) # t18\n", - " t19 = ltorch.mul(t12, t18) # t19\n", - " # t19 = prims.mul(t12, t18) # t19\n", - " t20 = ltorch.reshape(t19, -1, 64) # t20\n", - " # t20 = prims.reshape(t19, (64, 64)) # t20\n", - " t21 = ltorch.matmul(t20, t5) # t21\n", - " # t21 = prims.matmul(t20, t5) # t21\n", - " t22 = ltorch.reshape(t19, -1, 64) # t22\n", - " # t22 = prims.reshape(t19, (64, 64)) # t22\n", - " t23 = prims.transpose(t22, (1, 0)) # t23\n", - " t24 = ltorch.reshape(t2, -1, 64) # t24\n", - " # t24 = prims.reshape(t2, (64, 64)) # t24\n", - " t25 = ltorch.matmul(t23, t24) # t25\n", - " # t25 = prims.matmul(t23, t24) # t25\n", - " t26 = ltorch.true_divide(t16, 2) # t26\n", - " # _ = prims.convert_element_type(2, float)\n", - " # t26 = prims.div(t16, 2.0) # t26\n", - " p27 = thunder.distributed.prims.reduce_scatter(t26, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p27\n", - " t28 = thunder.distributed.prims.wait(p27) # t28\n", - " t29 = ltorch.true_divide(t25, 2) # t29\n", - " # _ = prims.convert_element_type(2, float)\n", - " # t29 = prims.div(t25, 2.0) # t29\n", - " p30 = thunder.distributed.prims.reduce_scatter(t29, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p30\n", - " t31 = thunder.distributed.prims.wait(p30) # t31\n", - " return (t10, (t31, t28, {'x': t21}))" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from thunder.core.transforms import value_and_grad\n", - "\n", - "forward_and_backward_model = value_and_grad(model_with_syncs)\n", - "\n", - "forward_backward_trace = thunder.trace()(forward_and_backward_model, *model.parameters(), x=x)\n", - "\n", - "wrap_as_highlighted_code(forward_backward_trace)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The above trace, only contains primitive which specifies the semantic of an operation abstractly but doesn't perform the actual computation.\n", - "\n", - "Now we will generate the execution trace which can actually perform the compute.\n", - "\n", - "In the execution trace generated below, we can see that all the primitives have been replaced with actually PyTorch operations. Also, our synchronization primitives have been replaced with PyTorch implementation provided by thunder i.e. `torch_all_gather_prim_impl`, `torch_reduce_scatter_prim_impl`, `torch_wait_prim_impl`." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
# Constructed by Delete Last Used (took 0 milliseconds)\n",
-       "import torch\n",
-       "import torch.nn.functional\n",
-       "from thunder.executors.torchex import no_autocast\n",
-       "\n",
-       "@torch.no_grad()\n",
-       "@no_autocast()\n",
-       "def _value_and_grad(*args, **kwargs):\n",
-       "  # args \n",
-       "  # kwargs \n",
-       "  t0, \\\n",
-       "  t1, \\\n",
-       "  = args\n",
-       "  del args\n",
-       "  t2 = kwargs['x']\n",
-       "  del kwargs\n",
-       "  t3 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t3\n",
-       "    # t3 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t3\n",
-       "      # t3 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t3\n",
-       "  p4 = torch_all_gather_prim_impl(t0, _torch_distributed_distributed_c10d_ProcessGroup_2, True)  # p4\n",
-       "  del t0\n",
-       "  t5 = torch_wait_prim_impl(p4)  # t5\n",
-       "  del p4\n",
-       "  p6 = torch_all_gather_prim_impl(t1, _torch_distributed_distributed_c10d_ProcessGroup_2, True)  # p6\n",
-       "  del t1\n",
-       "  t7 = torch_wait_prim_impl(p6)  # t7\n",
-       "  del p6\n",
-       "  t8 = torch.nn.functional.linear(t2, t5, None)  # t8\n",
-       "    # t8 = ltorch.linear(t2, t5, None)  # t8\n",
-       "      # t8 = prims.linear(t2, t5, None)  # t8\n",
-       "  t9 = torch.tanh(t8)  # t9\n",
-       "    # t9 = ltorch.tanh(t8)  # t9\n",
-       "      # t9 = prims.tanh(t8)  # t9\n",
-       "  del t8\n",
-       "  t10 = torch.nn.functional.linear(t9, t7, None)  # t10\n",
-       "    # t10 = ltorch.linear(t9, t7, None)  # t10\n",
-       "      # t10 = prims.linear(t9, t7, None)  # t10\n",
-       "  t11 = torch.reshape(t3, (-1, 64))  # t11\n",
-       "    # t11 = ltorch.reshape(t3, (-1, 64))  # t11\n",
-       "      # t11 = prims.reshape(t3, (64, 64))  # t11\n",
-       "  t12 = torch.matmul(t11, t7)  # t12\n",
-       "    # t12 = ltorch.matmul(t11, t7)  # t12\n",
-       "      # t12 = prims.matmul(t11, t7)  # t12\n",
-       "  del t11, t7\n",
-       "  t13 = torch.reshape(t3, (-1, 64))  # t13\n",
-       "    # t13 = ltorch.reshape(t3, (-1, 64))  # t13\n",
-       "      # t13 = prims.reshape(t3, (64, 64))  # t13\n",
-       "  del t3\n",
-       "  t14 = torch.permute(t13, (1, 0))  # t14\n",
-       "    # t14 = ltorch.permute(t13, (1, 0))  # t14\n",
-       "      # t14 = prims.transpose(t13, (1, 0))  # t14\n",
-       "  del t13\n",
-       "  t15 = torch.reshape(t9, (-1, 64))  # t15\n",
-       "    # t15 = ltorch.reshape(t9, (-1, 64))  # t15\n",
-       "      # t15 = prims.reshape(t9, (64, 64))  # t15\n",
-       "  t16 = torch.matmul(t14, t15)  # t16\n",
-       "    # t16 = ltorch.matmul(t14, t15)  # t16\n",
-       "      # t16 = prims.matmul(t14, t15)  # t16\n",
-       "  del t14, t15\n",
-       "  t17 = torch.mul(t9, t9)  # t17\n",
-       "    # t17 = ltorch.mul(t9, t9)  # t17\n",
-       "      # t17 = prims.mul(t9, t9)  # t17\n",
-       "  del t9\n",
-       "  t18 = torch.sub(1.0, t17)  # t18\n",
-       "    # t18 = ltorch.sub(1.0, t17, alpha=None)  # t18\n",
-       "      # t18 = prims.sub(1.0, t17)  # t18\n",
-       "  del t17\n",
-       "  t19 = torch.mul(t12, t18)  # t19\n",
-       "    # t19 = ltorch.mul(t12, t18)  # t19\n",
-       "      # t19 = prims.mul(t12, t18)  # t19\n",
-       "  del t12, t18\n",
-       "  t20 = torch.reshape(t19, (-1, 64))  # t20\n",
-       "    # t20 = ltorch.reshape(t19, (-1, 64))  # t20\n",
-       "      # t20 = prims.reshape(t19, (64, 64))  # t20\n",
-       "  t21 = torch.matmul(t20, t5)  # t21\n",
-       "    # t21 = ltorch.matmul(t20, t5)  # t21\n",
-       "      # t21 = prims.matmul(t20, t5)  # t21\n",
-       "  del t20, t5\n",
-       "  t22 = torch.reshape(t19, (-1, 64))  # t22\n",
-       "    # t22 = ltorch.reshape(t19, (-1, 64))  # t22\n",
-       "      # t22 = prims.reshape(t19, (64, 64))  # t22\n",
-       "  del t19\n",
-       "  t23 = torch.permute(t22, (1, 0))  # t23\n",
-       "    # t23 = ltorch.permute(t22, (1, 0))  # t23\n",
-       "      # t23 = prims.transpose(t22, (1, 0))  # t23\n",
-       "  del t22\n",
-       "  t24 = torch.reshape(t2, (-1, 64))  # t24\n",
-       "    # t24 = ltorch.reshape(t2, (-1, 64))  # t24\n",
-       "      # t24 = prims.reshape(t2, (64, 64))  # t24\n",
-       "  del t2\n",
-       "  t25 = torch.matmul(t23, t24)  # t25\n",
-       "    # t25 = ltorch.matmul(t23, t24)  # t25\n",
-       "      # t25 = prims.matmul(t23, t24)  # t25\n",
-       "  del t23, t24\n",
-       "  t26 = torch.true_divide(t16, 2)  # t26\n",
-       "    # t26 = ltorch.true_divide(t16, 2)  # t26\n",
-       "      # _ = prims.convert_element_type(2, float)\n",
-       "      # t26 = prims.div(t16, 2.0)  # t26\n",
-       "  del t16\n",
-       "  p27 = torch_reduce_scatter_prim_impl(t26, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True)  # p27\n",
-       "  del t26\n",
-       "  t28 = torch_wait_prim_impl(p27)  # t28\n",
-       "  del p27\n",
-       "  t29 = torch.true_divide(t25, 2)  # t29\n",
-       "    # t29 = ltorch.true_divide(t25, 2)  # t29\n",
-       "      # _ = prims.convert_element_type(2, float)\n",
-       "      # t29 = prims.div(t25, 2.0)  # t29\n",
-       "  del t25\n",
-       "  p30 = torch_reduce_scatter_prim_impl(t29, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True)  # p30\n",
-       "  del t29\n",
-       "  t31 = torch_wait_prim_impl(p30)  # t31\n",
-       "  del p30\n",
-       "  return (t10, (t31, t28, {'x': t21}))\n",
-       "
\n" - ], - "text/latex": [ - "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", - "\\PY{c+c1}{\\PYZsh{} Constructed by Delete Last Used (took 0 milliseconds)}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{torch}\\PY{n+nn}{.}\\PY{n+nn}{nn}\\PY{n+nn}{.}\\PY{n+nn}{functional}\n", - "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n", - "\n", - "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{k}{def} \\PY{n+nf}{\\PYZus{}value\\PYZus{}and\\PYZus{}grad}\\PY{p}{(}\\PY{o}{*}\\PY{n}{args}\\PY{p}{,} \\PY{o}{*}\\PY{o}{*}\\PY{n}{kwargs}\\PY{p}{)}\\PY{p}{:}\n", - " \\PY{c+c1}{\\PYZsh{} args }\n", - " \\PY{c+c1}{\\PYZsh{} kwargs }\n", - " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{o}{=} \\PY{n}{args}\n", - " \\PY{k}{del} \\PY{n}{args}\n", - " \\PY{n}{t2} \\PY{o}{=} \\PY{n}{kwargs}\\PY{p}{[}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{x}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{]}\n", - " \\PY{k}{del} \\PY{n}{kwargs}\n", - " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3}\n", - " \\PY{c+c1}{\\PYZsh{} t3 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t3}\n", - " \\PY{c+c1}{\\PYZsh{} t3 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t3}\n", - " \\PY{n}{p4} \\PY{o}{=} \\PY{n}{torch\\PYZus{}all\\PYZus{}gather\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}2}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p4}\n", - " \\PY{k}{del} \\PY{n}{t0}\n", - " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p4}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5}\n", - " \\PY{k}{del} \\PY{n}{p4}\n", - " \\PY{n}{p6} \\PY{o}{=} \\PY{n}{torch\\PYZus{}all\\PYZus{}gather\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{t1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}2}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p6}\n", - " \\PY{k}{del} \\PY{n}{t1}\n", - " \\PY{n}{t7} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p6}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t7}\n", - " \\PY{k}{del} \\PY{n}{p6}\n", - " \\PY{n}{t8} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{functional}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t2}\\PY{p}{,} \\PY{n}{t5}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t8}\n", - " \\PY{c+c1}{\\PYZsh{} t8 = ltorch.linear(t2, t5, None) \\PYZsh{} t8}\n", - " \\PY{c+c1}{\\PYZsh{} t8 = prims.linear(t2, t5, None) \\PYZsh{} t8}\n", - " \\PY{n}{t9} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t8}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t9}\n", - " \\PY{c+c1}{\\PYZsh{} t9 = ltorch.tanh(t8) \\PYZsh{} t9}\n", - " \\PY{c+c1}{\\PYZsh{} t9 = prims.tanh(t8) \\PYZsh{} t9}\n", - " \\PY{k}{del} \\PY{n}{t8}\n", - " \\PY{n}{t10} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{functional}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t9}\\PY{p}{,} \\PY{n}{t7}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t10}\n", - " \\PY{c+c1}{\\PYZsh{} t10 = ltorch.linear(t9, t7, None) \\PYZsh{} t10}\n", - " \\PY{c+c1}{\\PYZsh{} t10 = prims.linear(t9, t7, None) \\PYZsh{} t10}\n", - " \\PY{n}{t11} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t11}\n", - " \\PY{c+c1}{\\PYZsh{} t11 = ltorch.reshape(t3, (\\PYZhy{}1, 64)) \\PYZsh{} t11}\n", - " \\PY{c+c1}{\\PYZsh{} t11 = prims.reshape(t3, (64, 64)) \\PYZsh{} t11}\n", - " \\PY{n}{t12} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t11}\\PY{p}{,} \\PY{n}{t7}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t12}\n", - " \\PY{c+c1}{\\PYZsh{} t12 = ltorch.matmul(t11, t7) \\PYZsh{} t12}\n", - " \\PY{c+c1}{\\PYZsh{} t12 = prims.matmul(t11, t7) \\PYZsh{} t12}\n", - " \\PY{k}{del} \\PY{n}{t11}\\PY{p}{,} \\PY{n}{t7}\n", - " \\PY{n}{t13} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t13}\n", - " \\PY{c+c1}{\\PYZsh{} t13 = ltorch.reshape(t3, (\\PYZhy{}1, 64)) \\PYZsh{} t13}\n", - " \\PY{c+c1}{\\PYZsh{} t13 = prims.reshape(t3, (64, 64)) \\PYZsh{} t13}\n", - " \\PY{k}{del} \\PY{n}{t3}\n", - " \\PY{n}{t14} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{permute}\\PY{p}{(}\\PY{n}{t13}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t14}\n", - " \\PY{c+c1}{\\PYZsh{} t14 = ltorch.permute(t13, (1, 0)) \\PYZsh{} t14}\n", - " \\PY{c+c1}{\\PYZsh{} t14 = prims.transpose(t13, (1, 0)) \\PYZsh{} t14}\n", - " \\PY{k}{del} \\PY{n}{t13}\n", - " \\PY{n}{t15} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t9}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t15}\n", - " \\PY{c+c1}{\\PYZsh{} t15 = ltorch.reshape(t9, (\\PYZhy{}1, 64)) \\PYZsh{} t15}\n", - " \\PY{c+c1}{\\PYZsh{} t15 = prims.reshape(t9, (64, 64)) \\PYZsh{} t15}\n", - " \\PY{n}{t16} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t14}\\PY{p}{,} \\PY{n}{t15}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t16}\n", - " \\PY{c+c1}{\\PYZsh{} t16 = ltorch.matmul(t14, t15) \\PYZsh{} t16}\n", - " \\PY{c+c1}{\\PYZsh{} t16 = prims.matmul(t14, t15) \\PYZsh{} t16}\n", - " \\PY{k}{del} \\PY{n}{t14}\\PY{p}{,} \\PY{n}{t15}\n", - " \\PY{n}{t17} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t9}\\PY{p}{,} \\PY{n}{t9}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t17}\n", - " \\PY{c+c1}{\\PYZsh{} t17 = ltorch.mul(t9, t9) \\PYZsh{} t17}\n", - " \\PY{c+c1}{\\PYZsh{} t17 = prims.mul(t9, t9) \\PYZsh{} t17}\n", - " \\PY{k}{del} \\PY{n}{t9}\n", - " \\PY{n}{t18} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{sub}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{,} \\PY{n}{t17}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t18}\n", - " \\PY{c+c1}{\\PYZsh{} t18 = ltorch.sub(1.0, t17, alpha=None) \\PYZsh{} t18}\n", - " \\PY{c+c1}{\\PYZsh{} t18 = prims.sub(1.0, t17) \\PYZsh{} t18}\n", - " \\PY{k}{del} \\PY{n}{t17}\n", - " \\PY{n}{t19} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t12}\\PY{p}{,} \\PY{n}{t18}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t19}\n", - " \\PY{c+c1}{\\PYZsh{} t19 = ltorch.mul(t12, t18) \\PYZsh{} t19}\n", - " \\PY{c+c1}{\\PYZsh{} t19 = prims.mul(t12, t18) \\PYZsh{} t19}\n", - " \\PY{k}{del} \\PY{n}{t12}\\PY{p}{,} \\PY{n}{t18}\n", - " \\PY{n}{t20} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t19}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t20}\n", - " \\PY{c+c1}{\\PYZsh{} t20 = ltorch.reshape(t19, (\\PYZhy{}1, 64)) \\PYZsh{} t20}\n", - " \\PY{c+c1}{\\PYZsh{} t20 = prims.reshape(t19, (64, 64)) \\PYZsh{} t20}\n", - " \\PY{n}{t21} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t20}\\PY{p}{,} \\PY{n}{t5}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t21}\n", - " \\PY{c+c1}{\\PYZsh{} t21 = ltorch.matmul(t20, t5) \\PYZsh{} t21}\n", - " \\PY{c+c1}{\\PYZsh{} t21 = prims.matmul(t20, t5) \\PYZsh{} t21}\n", - " \\PY{k}{del} \\PY{n}{t20}\\PY{p}{,} \\PY{n}{t5}\n", - " \\PY{n}{t22} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t19}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t22}\n", - " \\PY{c+c1}{\\PYZsh{} t22 = ltorch.reshape(t19, (\\PYZhy{}1, 64)) \\PYZsh{} t22}\n", - " \\PY{c+c1}{\\PYZsh{} t22 = prims.reshape(t19, (64, 64)) \\PYZsh{} t22}\n", - " \\PY{k}{del} \\PY{n}{t19}\n", - " \\PY{n}{t23} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{permute}\\PY{p}{(}\\PY{n}{t22}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t23}\n", - " \\PY{c+c1}{\\PYZsh{} t23 = ltorch.permute(t22, (1, 0)) \\PYZsh{} t23}\n", - " \\PY{c+c1}{\\PYZsh{} t23 = prims.transpose(t22, (1, 0)) \\PYZsh{} t23}\n", - " \\PY{k}{del} \\PY{n}{t22}\n", - " \\PY{n}{t24} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t2}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t24}\n", - " \\PY{c+c1}{\\PYZsh{} t24 = ltorch.reshape(t2, (\\PYZhy{}1, 64)) \\PYZsh{} t24}\n", - " \\PY{c+c1}{\\PYZsh{} t24 = prims.reshape(t2, (64, 64)) \\PYZsh{} t24}\n", - " \\PY{k}{del} \\PY{n}{t2}\n", - " \\PY{n}{t25} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t23}\\PY{p}{,} \\PY{n}{t24}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t25}\n", - " \\PY{c+c1}{\\PYZsh{} t25 = ltorch.matmul(t23, t24) \\PYZsh{} t25}\n", - " \\PY{c+c1}{\\PYZsh{} t25 = prims.matmul(t23, t24) \\PYZsh{} t25}\n", - " \\PY{k}{del} \\PY{n}{t23}\\PY{p}{,} \\PY{n}{t24}\n", - " \\PY{n}{t26} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t26}\n", - " \\PY{c+c1}{\\PYZsh{} t26 = ltorch.true\\PYZus{}divide(t16, 2) \\PYZsh{} t26}\n", - " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n", - " \\PY{c+c1}{\\PYZsh{} t26 = prims.div(t16, 2.0) \\PYZsh{} t26}\n", - " \\PY{k}{del} \\PY{n}{t16}\n", - " \\PY{n}{p27} \\PY{o}{=} \\PY{n}{torch\\PYZus{}reduce\\PYZus{}scatter\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{t26}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}3}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}2}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p27}\n", - " \\PY{k}{del} \\PY{n}{t26}\n", - " \\PY{n}{t28} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p27}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t28}\n", - " \\PY{k}{del} \\PY{n}{p27}\n", - " \\PY{n}{t29} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t25}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t29}\n", - " \\PY{c+c1}{\\PYZsh{} t29 = ltorch.true\\PYZus{}divide(t25, 2) \\PYZsh{} t29}\n", - " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n", - " \\PY{c+c1}{\\PYZsh{} t29 = prims.div(t25, 2.0) \\PYZsh{} t29}\n", - " \\PY{k}{del} \\PY{n}{t25}\n", - " \\PY{n}{p30} \\PY{o}{=} \\PY{n}{torch\\PYZus{}reduce\\PYZus{}scatter\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{t29}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}3}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}2}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p30}\n", - " \\PY{k}{del} \\PY{n}{t29}\n", - " \\PY{n}{t31} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p30}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t31}\n", - " \\PY{k}{del} \\PY{n}{p30}\n", - " \\PY{k}{return} \\PY{p}{(}\\PY{n}{t10}\\PY{p}{,} \\PY{p}{(}\\PY{n}{t31}\\PY{p}{,} \\PY{n}{t28}\\PY{p}{,} \\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{x}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t21}\\PY{p}{\\PYZcb{}}\\PY{p}{)}\\PY{p}{)}\n", - "\\end{Verbatim}\n" - ], - "text/plain": [ - "# Constructed by Delete Last Used (took 0 milliseconds)\n", - "import torch\n", - "import torch.nn.functional\n", - "from thunder.executors.torchex import no_autocast\n", - "\n", - "@torch.no_grad()\n", - "@no_autocast()\n", - "def _value_and_grad(*args, **kwargs):\n", - " # args \n", - " # kwargs \n", - " t0, \\\n", - " t1, \\\n", - " = args\n", - " del args\n", - " t2 = kwargs['x']\n", - " del kwargs\n", - " t3 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t3\n", - " # t3 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t3\n", - " # t3 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t3\n", - " p4 = torch_all_gather_prim_impl(t0, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p4\n", - " del t0\n", - " t5 = torch_wait_prim_impl(p4) # t5\n", - " del p4\n", - " p6 = torch_all_gather_prim_impl(t1, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p6\n", - " del t1\n", - " t7 = torch_wait_prim_impl(p6) # t7\n", - " del p6\n", - " t8 = torch.nn.functional.linear(t2, t5, None) # t8\n", - " # t8 = ltorch.linear(t2, t5, None) # t8\n", - " # t8 = prims.linear(t2, t5, None) # t8\n", - " t9 = torch.tanh(t8) # t9\n", - " # t9 = ltorch.tanh(t8) # t9\n", - " # t9 = prims.tanh(t8) # t9\n", - " del t8\n", - " t10 = torch.nn.functional.linear(t9, t7, None) # t10\n", - " # t10 = ltorch.linear(t9, t7, None) # t10\n", - " # t10 = prims.linear(t9, t7, None) # t10\n", - " t11 = torch.reshape(t3, (-1, 64)) # t11\n", - " # t11 = ltorch.reshape(t3, (-1, 64)) # t11\n", - " # t11 = prims.reshape(t3, (64, 64)) # t11\n", - " t12 = torch.matmul(t11, t7) # t12\n", - " # t12 = ltorch.matmul(t11, t7) # t12\n", - " # t12 = prims.matmul(t11, t7) # t12\n", - " del t11, t7\n", - " t13 = torch.reshape(t3, (-1, 64)) # t13\n", - " # t13 = ltorch.reshape(t3, (-1, 64)) # t13\n", - " # t13 = prims.reshape(t3, (64, 64)) # t13\n", - " del t3\n", - " t14 = torch.permute(t13, (1, 0)) # t14\n", - " # t14 = ltorch.permute(t13, (1, 0)) # t14\n", - " # t14 = prims.transpose(t13, (1, 0)) # t14\n", - " del t13\n", - " t15 = torch.reshape(t9, (-1, 64)) # t15\n", - " # t15 = ltorch.reshape(t9, (-1, 64)) # t15\n", - " # t15 = prims.reshape(t9, (64, 64)) # t15\n", - " t16 = torch.matmul(t14, t15) # t16\n", - " # t16 = ltorch.matmul(t14, t15) # t16\n", - " # t16 = prims.matmul(t14, t15) # t16\n", - " del t14, t15\n", - " t17 = torch.mul(t9, t9) # t17\n", - " # t17 = ltorch.mul(t9, t9) # t17\n", - " # t17 = prims.mul(t9, t9) # t17\n", - " del t9\n", - " t18 = torch.sub(1.0, t17) # t18\n", - " # t18 = ltorch.sub(1.0, t17, alpha=None) # t18\n", - " # t18 = prims.sub(1.0, t17) # t18\n", - " del t17\n", - " t19 = torch.mul(t12, t18) # t19\n", - " # t19 = ltorch.mul(t12, t18) # t19\n", - " # t19 = prims.mul(t12, t18) # t19\n", - " del t12, t18\n", - " t20 = torch.reshape(t19, (-1, 64)) # t20\n", - " # t20 = ltorch.reshape(t19, (-1, 64)) # t20\n", - " # t20 = prims.reshape(t19, (64, 64)) # t20\n", - " t21 = torch.matmul(t20, t5) # t21\n", - " # t21 = ltorch.matmul(t20, t5) # t21\n", - " # t21 = prims.matmul(t20, t5) # t21\n", - " del t20, t5\n", - " t22 = torch.reshape(t19, (-1, 64)) # t22\n", - " # t22 = ltorch.reshape(t19, (-1, 64)) # t22\n", - " # t22 = prims.reshape(t19, (64, 64)) # t22\n", - " del t19\n", - " t23 = torch.permute(t22, (1, 0)) # t23\n", - " # t23 = ltorch.permute(t22, (1, 0)) # t23\n", - " # t23 = prims.transpose(t22, (1, 0)) # t23\n", - " del t22\n", - " t24 = torch.reshape(t2, (-1, 64)) # t24\n", - " # t24 = ltorch.reshape(t2, (-1, 64)) # t24\n", - " # t24 = prims.reshape(t2, (64, 64)) # t24\n", - " del t2\n", - " t25 = torch.matmul(t23, t24) # t25\n", - " # t25 = ltorch.matmul(t23, t24) # t25\n", - " # t25 = prims.matmul(t23, t24) # t25\n", - " del t23, t24\n", - " t26 = torch.true_divide(t16, 2) # t26\n", - " # t26 = ltorch.true_divide(t16, 2) # t26\n", - " # _ = prims.convert_element_type(2, float)\n", - " # t26 = prims.div(t16, 2.0) # t26\n", - " del t16\n", - " p27 = torch_reduce_scatter_prim_impl(t26, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p27\n", - " del t26\n", - " t28 = torch_wait_prim_impl(p27) # t28\n", - " del p27\n", - " t29 = torch.true_divide(t25, 2) # t29\n", - " # t29 = ltorch.true_divide(t25, 2) # t29\n", - " # _ = prims.convert_element_type(2, float)\n", - " # t29 = prims.div(t25, 2.0) # t29\n", - " del t25\n", - " p30 = torch_reduce_scatter_prim_impl(t29, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p30\n", - " del t29\n", - " t31 = torch_wait_prim_impl(p30) # t31\n", - " del p30\n", - " return (t10, (t31, t28, {'x': t21}))" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "optimized_trace = thunder.transform_for_execution(forward_backward_trace, executors_list=thunder.get_always_executors())\n", - "\n", - "# Grab the final trace\n", - "exec_trace = optimized_trace[-1]\n", - "wrap_as_highlighted_code(exec_trace)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Step 4 : Running the actual computation\n", - "\n", - "Running the actual computation will require setting up 2 processes and running our above code in both those processes (which can be tricky with Jupyter Notebook). Instead, we will write a small script and run it with `torchrun` which takes care of setting up the processes and relevant state.\n", - "\n", - "**NOTE**: This requires device running this notebook to have at least 2-GPUs" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In the example below, we will use `thunder.distributed.fsdp` which does the same as what we did above (with some extra checks). The code below should look familiar as it is roughly all the above pieces in a single script. " - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Overwriting thunder_fsdp_simple_example.py\n" - ] - } - ], - "source": [ - "%%writefile thunder_fsdp_simple_example.py\n", - "\n", - "# imports\n", - "from thunder.tests.lit_gpt_model import GPT, Config\n", - "import torch\n", - "import torch.distributed\n", - "import thunder\n", - "import thunder.distributed\n", - "import os\n", - "\n", - "# # # # # # # #\n", - "# Create Model\n", - "# # # # # # # #\n", - "\n", - "# NOTE: We create the model on CPU.\n", - "device='cpu'\n", - "dim = 64\n", - "def create_model():\n", - " layers = []\n", - " layers.append(torch.nn.Linear(dim, dim))\n", - " layers.append(torch.nn.ReLU())\n", - " layers.append(torch.nn.Linear(dim, dim))\n", - " return torch.nn.Sequential(*layers).to(device)\n", - "\n", - "# Model\n", - "model = create_model()\n", - "# Input\n", - "x = torch.randn(dim, dim, device=device)\n", - "\n", - "# # # # # # # #\n", - "# Setup for distributed\n", - "# # # # # # # #\n", - "torch.distributed.init_process_group(backend='nccl')\n", - "\n", - "rank = int(os.environ[\"LOCAL_RANK\"])\n", - "\n", - "device = f\"cuda:{rank}\"\n", - "\n", - "# # # # # # # #\n", - "# Move inputs to correct device\n", - "# # # # # # # #\n", - "x = x.to(device)\n", - "\n", - "# # # # # # # #\n", - "# Wrap the model in thunder.distributed.fsdp\n", - "# # # # # # # #\n", - "\n", - "# thunder.distributed.fsdp takes care of moving the parameter\n", - "# shard to the correct GPU for the current process.\n", - "cmodel = thunder.jit(thunder.distributed.fsdp(model))\n", - "\n", - "# Run the forward pass.\n", - "cmodel(x)\n", - "\n", - "# # # # # # # #\n", - "# Check the traces\n", - "# # # # # # # #\n", - "fwd_traces, bwd_traces = thunder.last_traces(cmodel)\n", - "\n", - "# # # # # # # #\n", - "# Print and check to see if they match ours\n", - "# # # # # # # #\n", - "if rank == 0:\n", - " print(fwd_traces[-1])\n", - " print(\"*******\"* 8)\n", - " print(bwd_traces[-1])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let us run the above script and check what the trace looks like.\n", - "\n", - "We can observe that forward trace has `torch_all_gather_prim_impl` to gather the parameter before forward pass and the backward trace has `torch_reduce_scatter_prim_impl` to reduce and scatter the gradients back to different GPUs. This is similar to our implementation above." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[2024-03-06 15:59:54,829] torch.distributed.run: [WARNING] \n", - "[2024-03-06 15:59:54,829] torch.distributed.run: [WARNING] *****************************************\n", - "[2024-03-06 15:59:54,829] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \n", - "[2024-03-06 15:59:54,829] torch.distributed.run: [WARNING] *****************************************\n", - "/home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", - " _torch_pytree._register_pytree_node(\n", - "/home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", - " _torch_pytree._register_pytree_node(\n", - "# Constructed by Delete Last Used (took 0 milliseconds)\n", - "import torch\n", - "import torch.nn.functional\n", - "from thunder.executors.torchex import no_autocast\n", - "\n", - "@torch.no_grad()\n", - "@no_autocast()\n", - "def augmented_forward_fn(t_0_weight, t_0_bias, t_0, t_2_weight, t_2_bias):\n", - " # t_0_weight \n", - " p0 = torch_all_gather_prim_impl(t_0_weight, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p0\n", - " # t_0_bias \n", - " p2 = torch_all_gather_prim_impl(t_0_bias, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p2\n", - " # t_0 \n", - " # t_2_weight \n", - " p7 = torch_all_gather_prim_impl(t_2_weight, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p7\n", - " # t_2_bias \n", - " p9 = torch_all_gather_prim_impl(t_2_bias, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p9\n", - " t1 = torch_wait_prim_impl(p0) # t1\n", - " del p0\n", - " t3 = torch_wait_prim_impl(p2) # t3\n", - " del p2\n", - " t4 = torch.nn.functional.linear(t_0, t1, t3) # t4\n", - " # t4 = ltorch.linear(t_0, t1, t3) # t4\n", - " # t4 = prims.linear(t_0, t1, t3) # t4\n", - " del t1, t3\n", - " [t5, t6] = nvFusion0(t4)\n", - " # t5 = prims.gt(t4, 0.0) # t5\n", - " # t6 = prims.where(t5, t4, 0.0) # t6\n", - " del t4\n", - " t8 = torch_wait_prim_impl(p7) # t8\n", - " del p7\n", - " t10 = torch_wait_prim_impl(p9) # t10\n", - " del p9\n", - " t11 = torch.nn.functional.linear(t6, t8, t10) # t11\n", - " # t11 = ltorch.linear(t6, t8, t10) # t11\n", - " # t11 = prims.linear(t6, t8, t10) # t11\n", - " del t10\n", - " return {'output': (t11, ()), 'flat_args': [t_0_weight, t_0_bias, t_0, t_2_weight, t_2_bias], 'flat_output': (t11,)}, ((t5, t6, t8, t_0), ())\n", - "********************************************************\n", - "# Constructed by Delete Last Used (took 0 milliseconds)\n", - "import torch\n", - "from thunder.executors.torchex import no_autocast\n", - "\n", - "@torch.no_grad()\n", - "@no_autocast()\n", - "def backward_fn(saved_for_backward, cotangents):\n", - " # saved_for_backward \n", - " # cotangents \n", - " C0, \\\n", - " _, \\\n", - " = saved_for_backward\n", - " clear_collection(saved_for_backward)\n", - " del saved_for_backward\n", - " t0, \\\n", - " = cotangents\n", - " clear_collection(cotangents)\n", - " del cotangents\n", - " t5, \\\n", - " t6, \\\n", - " t8, \\\n", - " t_0, \\\n", - " = C0\n", - " clear_collection(C0)\n", - " del C0\n", - " t31 = torch.reshape(t0, (-1, 64)) # t31\n", - " # t31 = ltorch.reshape(t0, (-1, 64)) # t31\n", - " # t31 = prims.reshape(t0, (64, 64)) # t31\n", - " t32 = torch.permute(t31, (1, 0)) # t32\n", - " # t32 = ltorch.permute(t31, (1, 0)) # t32\n", - " # t32 = prims.transpose(t31, (1, 0)) # t32\n", - " t33 = torch.reshape(t6, (-1, 64)) # t33\n", - " # t33 = ltorch.reshape(t6, (-1, 64)) # t33\n", - " # t33 = prims.reshape(t6, (64, 64)) # t33\n", - " del t6\n", - " t48 = torch.reshape(t_0, (-1, 64)) # t48\n", - " # t48 = ltorch.reshape(t_0, (-1, 64)) # t48\n", - " # t48 = prims.reshape(t_0, (64, 64)) # t48\n", - " del t_0\n", - " [t36] = nvFusion0(t0)\n", - " # t35 = prims.sum(t0, (0,)) # t35\n", - " # t36 = prims.div(t35, 2.0) # t36\n", - " del t0\n", - " p37 = torch_reduce_scatter_prim_impl(t36, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p37\n", - " del t36\n", - " t30 = torch.matmul(t31, t8) # t30\n", - " # t30 = ltorch.matmul(t29, t8) # t30\n", - " # t30 = prims.matmul(t29, t8) # t30\n", - " del t31, t8\n", - " t34 = torch.matmul(t32, t33) # t34\n", - " # t34 = ltorch.matmul(t32, t33) # t34\n", - " # t34 = prims.matmul(t32, t33) # t34\n", - " del t32, t33\n", - " [t39, t42, t51] = nvFusion1(t30, t34, t5)\n", - " # t42 = prims.where(t5, t30, 0.0) # t42\n", - " # t50 = prims.sum(t42, (0,)) # t50\n", - " # t51 = prims.div(t50, 2.0) # t51\n", - " # t39 = prims.div(t34, 2.0) # t39\n", - " del t30, t34, t5\n", - " p40 = torch_reduce_scatter_prim_impl(t39, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p40\n", - " del t39\n", - " p52 = torch_reduce_scatter_prim_impl(t51, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p52\n", - " del t51\n", - " t46 = torch.reshape(t42, (-1, 64)) # t46\n", - " # t46 = ltorch.reshape(t42, (-1, 64)) # t46\n", - " # t46 = prims.reshape(t42, (64, 64)) # t46\n", - " del t42\n", - " t47 = torch.permute(t46, (1, 0)) # t47\n", - " # t47 = ltorch.permute(t46, (1, 0)) # t47\n", - " # t47 = prims.transpose(t46, (1, 0)) # t47\n", - " del t46\n", - " t49 = torch.matmul(t47, t48) # t49\n", - " # t49 = ltorch.matmul(t47, t48) # t49\n", - " # t49 = prims.matmul(t47, t48) # t49\n", - " del t47, t48\n", - " [t54] = nvFusion2(t49)\n", - " # t54 = prims.div(t49, 2.0) # t54\n", - " del t49\n", - " p55 = torch_reduce_scatter_prim_impl(t54, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p55\n", - " del t54\n", - " t38 = torch_wait_prim_impl(p37) # t38\n", - " del p37\n", - " t41 = torch_wait_prim_impl(p40) # t41\n", - " del p40\n", - " t53 = torch_wait_prim_impl(p52) # t53\n", - " del p52\n", - " t56 = torch_wait_prim_impl(p55) # t56\n", - " del p55\n", - " return (t56, t53, None, t41, t38)\n" - ] - } - ], - "source": [ - "!torchrun --nproc_per_node=2 thunder_fsdp_simple_example.py" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Conclusion\n", - "\n", - "We have created our implementation of FSDP to shard our model across multiple GPUs. In the process, we also learned that:\n", - "\n", - "1. `thunder` provides us with primitives for synchronization across mutiple GPUs.\n", - "2. `thunder` also takes care of implementing the backward support for the synchronization primitives, so we don't have to explicitly do anything to get the backward working.\n", - "3. We can just easily apply `thunder.distributed.fsdp` to our model and it will take care of sharding the parameters and also adding synchronizations to our model. Also, we can easily check the modifications by inspecting the traces." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "pytorch-dev", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/zero_to_thunder.ipynb b/notebooks/zero_to_thunder.ipynb index 9c7a5468a3..a1a888cc72 100644 --- a/notebooks/zero_to_thunder.ipynb +++ b/notebooks/zero_to_thunder.ipynb @@ -3,275 +3,4189 @@ { "cell_type": "markdown", "id": "1638964c", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ - "# Zero to thunder" + "# Zero to Thunder\n", + "\n", + "Here we take a very short tour of what is possible with Thunder.\n", + "\n", + "To get started we import it (and a bunch of things for this notebook)." ] }, { "cell_type": "code", - "execution_count": 5, - "id": "e8953e57", + "execution_count": 1, + "id": "28b99b58", "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.insert(0, '..')\n", "\n", - "import torch, thunder\n", + "import torch, thunder" + ] + }, + { + "cell_type": "markdown", + "id": "54f87aba", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "## Compiling a first module with Thunder\n", "\n", - "from thunder.tests.lit_gpt_model import Config, Block" + "So let's get started! As a \"Hello World\", let us apply it to it to a small model, say, the MLP part found in Llama 2. We take it from LitGPT." ] }, { "cell_type": "code", - "execution_count": 37, - "id": "0a62c587", + "execution_count": 2, + "id": "892be718", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LLaMAMLP(\n", + " (fc_1): Linear(in_features=4096, out_features=11008, bias=False)\n", + " (fc_2): Linear(in_features=4096, out_features=11008, bias=False)\n", + " (proj): Linear(in_features=11008, out_features=4096, bias=False)\n", + ")\n" + ] + } + ], + "source": [ + "class LLaMAMLP(torch.nn.Module):\n", + " def __init__(self, n_embd, intermediate_size) -> None:\n", + " super().__init__()\n", + " self.fc_1 = torch.nn.Linear(n_embd, intermediate_size, bias=False)\n", + " self.fc_2 = torch.nn.Linear(n_embd, intermediate_size, bias=False)\n", + " self.proj = torch.nn.Linear(intermediate_size, n_embd, bias=False)\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " x_fc_1 = self.fc_1(x)\n", + " x_fc_2 = self.fc_2(x)\n", + " x = torch.nn.functional.silu(x_fc_1) * x_fc_2\n", + " return self.proj(x)\n", + "with torch.device(\"cuda\"):\n", + " m = LLaMAMLP(4096, 11008)\n", + "for p in m.parameters():\n", + " p.requires_grad_(False)\n", + "print(m)\n" + ] + }, + { + "cell_type": "markdown", + "id": "702ea054", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ - "from lit_gpt.model import Config, LLaMAMLP" + "Now we can apply Thunder. This uses the most important function of Thunder, `thunder.jit`, which can be used to compile a `torch.nn.Module` or a function. It will wrap our MLP in a `ThunderModule`" ] }, { "cell_type": "code", - "execution_count": 51, - "id": "d6ca6328", + "execution_count": 3, + "id": "67ca2d37", "metadata": {}, "outputs": [], "source": [ - "cfg = Config.from_name('Llama-2-7b-hf')\n", - "with torch.device(\"cuda\"):\n", - " m = LLaMAMLP(cfg)\n", - "\n" + "thunder_model = thunder.jit(m)" ] }, { "cell_type": "code", - "execution_count": 52, - "id": "3a159966", + "execution_count": 4, + "id": "964e2689", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "LLaMAMLP(\n", - " (fc_1): Linear(in_features=4096, out_features=11008, bias=False)\n", - " (fc_2): Linear(in_features=4096, out_features=11008, bias=False)\n", - " (proj): Linear(in_features=11008, out_features=4096, bias=False)\n", + "ThunderModule(\n", + " (_model): LLaMAMLP(\n", + " (fc_1): Linear(in_features=4096, out_features=11008, bias=False)\n", + " (fc_2): Linear(in_features=4096, out_features=11008, bias=False)\n", + " (proj): Linear(in_features=11008, out_features=4096, bias=False)\n", + " )\n", ")" ] }, - "execution_count": 52, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "m" + "thunder_model" + ] + }, + { + "cell_type": "markdown", + "id": "47d24f2d-0e89-4fe8-8154-9b50f2633e1b", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "Our Thunder module computes (up to numerical accuracy) the same thing as our original model and for a small model like this, it also has approximately the same performance." ] }, { "cell_type": "code", - "execution_count": 53, - "id": "67ca2d37", - "metadata": {}, - "outputs": [], + "execution_count": 5, + "id": "7f4de1b3", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "deviation: 1.4901161193847656e-07\n", + "61.3 ms ± 106 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n", + "62.1 ms ± 89.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "x = torch.randn(2, 2048, 4096, device=\"cuda\")\n", + "print('deviation:', (thunder_model(x) - m(x)).abs().max().item())\n", + "\n", + "%timeit thunder_model(x); torch.cuda.synchronize()\n", + "%timeit m(x); torch.cuda.synchronize()" + ] + }, + { + "cell_type": "markdown", + "id": "7996acc7-de20-4aa5-80f0-1ab6042e2650", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ - "tm = thunder.jit(m)" + "So what has changed? Quite a bit!\n", + "\n", + "When we call the Thunder module, it do the computation in a single function without control flow. And what's more, it applies optimizations, such as creating fusions for NVFuser to execute. We can see all this by showing the last computation trace:" ] }, { "cell_type": "code", - "execution_count": 54, - "id": "964e2689", + "execution_count": 6, + "id": "a6f4b77c", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "# Constructed by Delete Last Used (took 0 milliseconds)\n", + "import torch\n", + "import torch.nn.functional\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight):\n", + " # x: \"cuda:0 f32[2, 2048, 4096]\" \n", + " # t_fc_1_weight: \"cuda:0 f32[11008, 4096]\" \n", + " # t_fc_2_weight: \"cuda:0 f32[11008, 4096]\" \n", + " # t_proj_weight: \"cuda:0 f32[4096, 11008]\" \n", + " x_fc_1 = torch.nn.functional.linear(x, t_fc_1_weight, None) # x_fc_1: \"cuda:0 f32[2, 2048, 11008]\"\n", + " # x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: \"cuda:0 f32[2, 2048, 11008]\"\n", + " # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: \"cuda:0 f32[2, 2048, 11008]\"\n", + " del t_fc_1_weight\n", + " x_fc_2 = torch.nn.functional.linear(x, t_fc_2_weight, None) # x_fc_2: \"cuda:0 f32[2, 2048, 11008]\"\n", + " # x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: \"cuda:0 f32[2, 2048, 11008]\"\n", + " # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: \"cuda:0 f32[2, 2048, 11008]\"\n", + " del x, t_fc_2_weight\n", + " [result] = nvFusion0(x_fc_1, x_fc_2)\n", + " # t9 = prims.neg(x_fc_1) # t9: \"cuda:0 f32[2, 2048, 11008]\"\n", + " # t10 = prims.exp(t9) # t10: \"cuda:0 f32[2, 2048, 11008]\"\n", + " # t11 = prims.add(1.0, t10) # t11: \"cuda:0 f32[2, 2048, 11008]\"\n", + " # t12 = prims.reciprocal(t11) # t12: \"cuda:0 f32[2, 2048, 11008]\"\n", + " # a = prims.mul(x_fc_1, t12) # a: \"cuda:0 f32[2, 2048, 11008]\"\n", + " # result = prims.mul(a, x_fc_2) # result: \"cuda:0 f32[2, 2048, 11008]\"\n", + " del x_fc_1, x_fc_2\n", + " t18 = torch.nn.functional.linear(result, t_proj_weight, None) # t18: \"cuda:0 f32[2, 2048, 4096]\"\n", + " # t18 = ltorch.linear(result, t_proj_weight, None) # t18: \"cuda:0 f32[2, 2048, 4096]\"\n", + " # t18 = prims.linear(result, t_proj_weight, None) # t18: \"cuda:0 f32[2, 2048, 4096]\"\n", + " del result, t_proj_weight\n", + " return t18" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "thunder.last_traces(thunder_model)[-1]" + ] + }, + { + "cell_type": "markdown", + "id": "2ef89186-70cd-4737-9695-ed282da2a56c", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "For more detail of what is going on in this trace:\n", + "- Thunder has transformed the computation (more precisely, `m.__call__`) into a single function which has all the MLP parameters as arguments.\n", + "- It has recorded the tensor metadata.\n", + "- Operations have been mapped from the PyTorch functions to `thunder.torch`(aka `ltorch`) equivalents and decomposed into _primitive operations_.\n", + "- The multiplication and activation (`x = torch.nn.functional.silu(x_fc_1) * x_fc_2`have been put into one NVFuser fusion. (NVFuser here is (a particularly important) one of many optimizations, and we make it easy to add your own.) \n", + "- You can see how the parameters are obtained and the metadata is checked in the prologue - get it through `thunder.last_prologue_traces(thunder_model)[-1]`.\n", + "\n", + "You can actually see the series of traces, `last_traces` gives you a list of transformed traces in chronological order - for example the initial trace `thunder.last_traces(thunder_model)[0]` does not have the fusion yet.\n" + ] + }, + { + "cell_type": "markdown", + "id": "7749aed1", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "## Compiling a more complex model\n", + "\n", + "Obviously, we aim for larger models, so we can do the same with the entire LLama 2 (well, we have a smaller momdel here to be mild to our CI, but if you have a large GPU, just drop reducing the number of layers):\n", + "\n", + "**NOTE**: For running the cells below, we require `litgpt` which can be installed with `pip install 'litgpt[all] @ git+https://github.com/Lightning-AI/litgpt'`. See [here](https://github.com/Lightning-AI/litgpt) to learn more about litgpt." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "d53e0c43", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "ThunderModule(\n", - " (_model): LLaMAMLP(\n", - " (fc_1): Linear(in_features=4096, out_features=11008, bias=False)\n", - " (fc_2): Linear(in_features=4096, out_features=11008, bias=False)\n", - " (proj): Linear(in_features=11008, out_features=4096, bias=False)\n", + "GPT(\n", + " (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n", + " (transformer): ModuleDict(\n", + " (wte): Embedding(32000, 4096)\n", + " (h): ModuleList(\n", + " (0-15): 16 x Block(\n", + " (norm_1): RMSNorm()\n", + " (attn): CausalSelfAttention(\n", + " (attn): Linear(in_features=4096, out_features=12288, bias=False)\n", + " (proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " )\n", + " (norm_2): RMSNorm()\n", + " (mlp): LLaMAMLP(\n", + " (fc_1): Linear(in_features=4096, out_features=11008, bias=False)\n", + " (fc_2): Linear(in_features=4096, out_features=11008, bias=False)\n", + " (proj): Linear(in_features=11008, out_features=4096, bias=False)\n", + " )\n", + " )\n", + " )\n", + " (ln_f): RMSNorm()\n", " )\n", ")" ] }, - "execution_count": 54, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "tm" + "from lit_gpt import GPT\n", + "from thunder.tests.lit_gpt_model import Config\n", + "cfg = Config.from_name('Llama-2-7b-hf')\n", + "cfg.n_layer = 16 # fewer layers\n", + "torch.set_default_dtype(torch.bfloat16)\n", + "with torch.device('cuda'):\n", + " m = GPT(cfg)\n", + "m\n" + ] + }, + { + "cell_type": "markdown", + "id": "e536a4aa", + "metadata": {}, + "source": [ + "Again we jit our model and compare the output..." ] }, { "cell_type": "code", - "execution_count": 60, - "id": "7f4de1b3", + "execution_count": 8, + "id": "36a7be96", "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "deviation: 0.03125\n" + ] + } + ], + "source": [ + "thunder_model = thunder.jit(m)\n", + "\n", + "inp = torch.randint(1, m.config.vocab_size, (1, 512), device=\"cuda\")\n", + "\n", + "actual = thunder_model(inp)\n", + "expected = m(inp)\n", + "\n", + "print(\"deviation:\", (actual - expected).abs().max().item())\n" + ] + }, + { + "cell_type": "markdown", + "id": "9947e8df-cd2d-447d-90b9-ee08bb5a9fb2", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "One thing to keep in mind here is that for bf16, the numerical accuracy impact of rearranging operations can be quite pronounced.\n", + "\n", + "Just like before, we can see the program it ran, it is a lot longer, though." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "ac7e8bc9", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, { "data": { "text/plain": [ - "tensor(1.4901e-07, device='cuda:0', grad_fn=)" + "# Constructed by Delete Last Used (took 10 milliseconds)\n", + "import torch\n", + "from torch import Tensor\n", + "import torch.nn.functional\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def augmented_forward_fn(*args):\n", + " # args: \"Collection\" \n", + " t0, \\\n", + " t1, \\\n", + " t2, \\\n", + " t3, \\\n", + " t4, \\\n", + " t5, \\\n", + " t6, \\\n", + " t7, \\\n", + " t8, \\\n", + " t9, \\\n", + " t10, \\\n", + " t11, \\\n", + " t12, \\\n", + " t13, \\\n", + " t14, \\\n", + " t15, \\\n", + " t16, \\\n", + " t17, \\\n", + " t18, \\\n", + " t19, \\\n", + " t20, \\\n", + " t21, \\\n", + " t22, \\\n", + " t23, \\\n", + " t24, \\\n", + " t25, \\\n", + " t26, \\\n", + " t27, \\\n", + " t28, \\\n", + " t29, \\\n", + " t30, \\\n", + " t31, \\\n", + " t32, \\\n", + " t33, \\\n", + " t34, \\\n", + " t35, \\\n", + " t36, \\\n", + " t37, \\\n", + " t38, \\\n", + " t39, \\\n", + " t40, \\\n", + " t41, \\\n", + " t42, \\\n", + " t43, \\\n", + " t44, \\\n", + " t45, \\\n", + " t46, \\\n", + " t47, \\\n", + " t48, \\\n", + " t49, \\\n", + " t50, \\\n", + " t51, \\\n", + " t52, \\\n", + " t53, \\\n", + " t54, \\\n", + " t55, \\\n", + " t56, \\\n", + " t57, \\\n", + " t58, \\\n", + " t59, \\\n", + " t60, \\\n", + " t61, \\\n", + " t62, \\\n", + " t63, \\\n", + " t64, \\\n", + " t65, \\\n", + " t66, \\\n", + " t67, \\\n", + " t68, \\\n", + " t69, \\\n", + " t70, \\\n", + " t71, \\\n", + " t72, \\\n", + " t73, \\\n", + " t74, \\\n", + " t75, \\\n", + " t76, \\\n", + " t77, \\\n", + " t78, \\\n", + " t79, \\\n", + " t80, \\\n", + " t81, \\\n", + " t82, \\\n", + " t83, \\\n", + " t84, \\\n", + " t85, \\\n", + " t86, \\\n", + " t87, \\\n", + " t88, \\\n", + " t89, \\\n", + " t90, \\\n", + " t91, \\\n", + " t92, \\\n", + " t93, \\\n", + " t94, \\\n", + " t95, \\\n", + " t96, \\\n", + " t97, \\\n", + " t98, \\\n", + " t99, \\\n", + " t100, \\\n", + " t101, \\\n", + " t102, \\\n", + " t103, \\\n", + " t104, \\\n", + " t105, \\\n", + " t106, \\\n", + " t107, \\\n", + " t108, \\\n", + " t109, \\\n", + " t110, \\\n", + " t111, \\\n", + " t112, \\\n", + " t113, \\\n", + " t114, \\\n", + " t115, \\\n", + " t116, \\\n", + " t117, \\\n", + " = args\n", + " del args\n", + " t122 = torch.nn.functional.embedding(t0, t117, None, None, 2.0, False, False) # t122: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t122 = ltorch.embedding(t0, t117, None, None, 2.0, False, False) # t122: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1867 = ltorch.reshape(t0, [512]) # t1867: \"cuda:0 i64[512]\"\n", + " # t1867 = prims.reshape(t0, (512,)) # t1867: \"cuda:0 i64[512]\"\n", + " # t1868 = prims.take(t117, t1867, 0) # t1868: \"cuda:0 bf16[512, 4096]\"\n", + " # t122 = ltorch.reshape(t1868, [1, 512, 4096]) # t122: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t122 = prims.reshape(t1868, (1, 512, 4096)) # t122: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t118 = torch_slice_prim_impl(t1, [0, 0], [512, 128], [1, 1]) # t118: \"cuda:0 f32[512, 128]\"\n", + " t119 = torch_slice_prim_impl(t2, [0, 0], [512, 128], [1, 1]) # t119: \"cuda:0 f32[512, 128]\"\n", + " t2015 = torch.unsqueeze(t53, 0) # t2015: \"cuda:0 bf16[1, 4096]\"\n", + " # t2015 = ltorch.unsqueeze(t53, 0) # t2015: \"cuda:0 bf16[1, 4096]\"\n", + " # t2015 = prims.broadcast_in_dim(t53, [1, 4096], [1]) # t2015: \"cuda:0 bf16[1, 4096]\"\n", + " t2016 = torch.unsqueeze(t2015, 1) # t2016: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2016 = ltorch.unsqueeze(t2015, 1) # t2016: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2016 = prims.broadcast_in_dim(t2015, [1, 1, 4096], [0, 2]) # t2016: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2015\n", + " t133 = Tensor.expand(t2016, (1, 512, 4096)) # t133: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t133 = ltorch.expand(t2016, (1, 512, 4096)) # t133: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t133 = prims.broadcast_in_dim(t2016, (1, 512, 4096), (0, 1, 2)) # t133: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2016\n", + " t2356 = torch.unsqueeze(t82, 0) # t2356: \"cuda:0 bf16[1, 4096]\"\n", + " # t2356 = ltorch.unsqueeze(t82, 0) # t2356: \"cuda:0 bf16[1, 4096]\"\n", + " # t2356 = prims.broadcast_in_dim(t82, [1, 4096], [1]) # t2356: \"cuda:0 bf16[1, 4096]\"\n", + " t2357 = torch.unsqueeze(t2356, 1) # t2357: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2357 = ltorch.unsqueeze(t2356, 1) # t2357: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2357 = prims.broadcast_in_dim(t2356, [1, 1, 4096], [0, 2]) # t2357: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2356\n", + " t1609 = Tensor.expand(t2357, (1, 512, 4096)) # t1609: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1609 = ltorch.expand(t2357, (1, 512, 4096)) # t1609: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1609 = prims.broadcast_in_dim(t2357, (1, 512, 4096), (0, 1, 2)) # t1609: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2357\n", + " t2359 = torch.unsqueeze(t58, 0) # t2359: \"cuda:0 bf16[1, 4096]\"\n", + " # t2359 = ltorch.unsqueeze(t58, 0) # t2359: \"cuda:0 bf16[1, 4096]\"\n", + " # t2359 = prims.broadcast_in_dim(t58, [1, 4096], [1]) # t2359: \"cuda:0 bf16[1, 4096]\"\n", + " t2360 = torch.unsqueeze(t2359, 1) # t2360: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2360 = ltorch.unsqueeze(t2359, 1) # t2360: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2360 = prims.broadcast_in_dim(t2359, [1, 1, 4096], [0, 2]) # t2360: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2359\n", + " t1645 = Tensor.expand(t2360, (1, 512, 4096)) # t1645: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1645 = ltorch.expand(t2360, (1, 512, 4096)) # t1645: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1645 = prims.broadcast_in_dim(t2360, (1, 512, 4096), (0, 1, 2)) # t1645: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2360\n", + " t2044 = torch.unsqueeze(t69, 0) # t2044: \"cuda:0 bf16[1, 4096]\"\n", + " # t2044 = ltorch.unsqueeze(t69, 0) # t2044: \"cuda:0 bf16[1, 4096]\"\n", + " # t2044 = prims.broadcast_in_dim(t69, [1, 4096], [1]) # t2044: \"cuda:0 bf16[1, 4096]\"\n", + " t2045 = torch.unsqueeze(t2044, 1) # t2045: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2045 = ltorch.unsqueeze(t2044, 1) # t2045: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2045 = prims.broadcast_in_dim(t2044, [1, 1, 4096], [0, 2]) # t2045: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2044\n", + " t205 = Tensor.expand(t2045, (1, 512, 4096)) # t205: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t205 = ltorch.expand(t2045, (1, 512, 4096)) # t205: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t205 = prims.broadcast_in_dim(t2045, (1, 512, 4096), (0, 1, 2)) # t205: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2045\n", + " t2380 = torch.unsqueeze(t83, 0) # t2380: \"cuda:0 bf16[1, 4096]\"\n", + " # t2380 = ltorch.unsqueeze(t83, 0) # t2380: \"cuda:0 bf16[1, 4096]\"\n", + " # t2380 = prims.broadcast_in_dim(t83, [1, 4096], [1]) # t2380: \"cuda:0 bf16[1, 4096]\"\n", + " t2381 = torch.unsqueeze(t2380, 1) # t2381: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2381 = ltorch.unsqueeze(t2380, 1) # t2381: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2381 = prims.broadcast_in_dim(t2380, [1, 1, 4096], [0, 2]) # t2381: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2380\n", + " t1717 = Tensor.expand(t2381, (1, 512, 4096)) # t1717: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1717 = ltorch.expand(t2381, (1, 512, 4096)) # t1717: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1717 = prims.broadcast_in_dim(t2381, (1, 512, 4096), (0, 1, 2)) # t1717: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2381\n", + " t2047 = torch.unsqueeze(t60, 0) # t2047: \"cuda:0 bf16[1, 4096]\"\n", + " # t2047 = ltorch.unsqueeze(t60, 0) # t2047: \"cuda:0 bf16[1, 4096]\"\n", + " # t2047 = prims.broadcast_in_dim(t60, [1, 4096], [1]) # t2047: \"cuda:0 bf16[1, 4096]\"\n", + " t2048 = torch.unsqueeze(t2047, 1) # t2048: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2048 = ltorch.unsqueeze(t2047, 1) # t2048: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2048 = prims.broadcast_in_dim(t2047, [1, 1, 4096], [0, 2]) # t2048: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2047\n", + " t241 = Tensor.expand(t2048, (1, 512, 4096)) # t241: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t241 = ltorch.expand(t2048, (1, 512, 4096)) # t241: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t241 = prims.broadcast_in_dim(t2048, (1, 512, 4096), (0, 1, 2)) # t241: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2048\n", + " t2383 = torch.unsqueeze(t59, 0) # t2383: \"cuda:0 bf16[1, 4096]\"\n", + " # t2383 = ltorch.unsqueeze(t59, 0) # t2383: \"cuda:0 bf16[1, 4096]\"\n", + " # t2383 = prims.broadcast_in_dim(t59, [1, 4096], [1]) # t2383: \"cuda:0 bf16[1, 4096]\"\n", + " t2384 = torch.unsqueeze(t2383, 1) # t2384: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2384 = ltorch.unsqueeze(t2383, 1) # t2384: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2384 = prims.broadcast_in_dim(t2383, [1, 1, 4096], [0, 2]) # t2384: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2383\n", + " t1753 = Tensor.expand(t2384, (1, 512, 4096)) # t1753: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1753 = ltorch.expand(t2384, (1, 512, 4096)) # t1753: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1753 = prims.broadcast_in_dim(t2384, (1, 512, 4096), (0, 1, 2)) # t1753: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2384\n", + " t2068 = torch.unsqueeze(t70, 0) # t2068: \"cuda:0 bf16[1, 4096]\"\n", + " # t2068 = ltorch.unsqueeze(t70, 0) # t2068: \"cuda:0 bf16[1, 4096]\"\n", + " # t2068 = prims.broadcast_in_dim(t70, [1, 4096], [1]) # t2068: \"cuda:0 bf16[1, 4096]\"\n", + " t2069 = torch.unsqueeze(t2068, 1) # t2069: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2069 = ltorch.unsqueeze(t2068, 1) # t2069: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2069 = prims.broadcast_in_dim(t2068, [1, 1, 4096], [0, 2]) # t2069: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2068\n", + " t313 = Tensor.expand(t2069, (1, 512, 4096)) # t313: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t313 = ltorch.expand(t2069, (1, 512, 4096)) # t313: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t313 = prims.broadcast_in_dim(t2069, (1, 512, 4096), (0, 1, 2)) # t313: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2069\n", + " t2404 = torch.unsqueeze(t84, 0) # t2404: \"cuda:0 bf16[1, 4096]\"\n", + " # t2404 = ltorch.unsqueeze(t84, 0) # t2404: \"cuda:0 bf16[1, 4096]\"\n", + " # t2404 = prims.broadcast_in_dim(t84, [1, 4096], [1]) # t2404: \"cuda:0 bf16[1, 4096]\"\n", + " t2405 = torch.unsqueeze(t2404, 1) # t2405: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2405 = ltorch.unsqueeze(t2404, 1) # t2405: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2405 = prims.broadcast_in_dim(t2404, [1, 1, 4096], [0, 2]) # t2405: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2404\n", + " t1825 = Tensor.expand(t2405, (1, 512, 4096)) # t1825: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1825 = ltorch.expand(t2405, (1, 512, 4096)) # t1825: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1825 = prims.broadcast_in_dim(t2405, (1, 512, 4096), (0, 1, 2)) # t1825: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2405\n", + " t2071 = torch.unsqueeze(t61, 0) # t2071: \"cuda:0 bf16[1, 4096]\"\n", + " # t2071 = ltorch.unsqueeze(t61, 0) # t2071: \"cuda:0 bf16[1, 4096]\"\n", + " # t2071 = prims.broadcast_in_dim(t61, [1, 4096], [1]) # t2071: \"cuda:0 bf16[1, 4096]\"\n", + " t2072 = torch.unsqueeze(t2071, 1) # t2072: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2072 = ltorch.unsqueeze(t2071, 1) # t2072: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2072 = prims.broadcast_in_dim(t2071, [1, 1, 4096], [0, 2]) # t2072: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2071\n", + " t349 = Tensor.expand(t2072, (1, 512, 4096)) # t349: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t349 = ltorch.expand(t2072, (1, 512, 4096)) # t349: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t349 = prims.broadcast_in_dim(t2072, (1, 512, 4096), (0, 1, 2)) # t349: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2072\n", + " t2407 = torch.unsqueeze(t52, 0) # t2407: \"cuda:0 bf16[1, 4096]\"\n", + " # t2407 = ltorch.unsqueeze(t52, 0) # t2407: \"cuda:0 bf16[1, 4096]\"\n", + " # t2407 = prims.broadcast_in_dim(t52, [1, 4096], [1]) # t2407: \"cuda:0 bf16[1, 4096]\"\n", + " t2408 = torch.unsqueeze(t2407, 1) # t2408: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2408 = ltorch.unsqueeze(t2407, 1) # t2408: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2408 = prims.broadcast_in_dim(t2407, [1, 1, 4096], [0, 2]) # t2408: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2407\n", + " t1861 = Tensor.expand(t2408, (1, 512, 4096)) # t1861: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1861 = ltorch.expand(t2408, (1, 512, 4096)) # t1861: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1861 = prims.broadcast_in_dim(t2408, (1, 512, 4096), (0, 1, 2)) # t1861: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2408\n", + " t2095 = torch.unsqueeze(t62, 0) # t2095: \"cuda:0 bf16[1, 4096]\"\n", + " # t2095 = ltorch.unsqueeze(t62, 0) # t2095: \"cuda:0 bf16[1, 4096]\"\n", + " # t2095 = prims.broadcast_in_dim(t62, [1, 4096], [1]) # t2095: \"cuda:0 bf16[1, 4096]\"\n", + " t2096 = torch.unsqueeze(t2095, 1) # t2096: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2096 = ltorch.unsqueeze(t2095, 1) # t2096: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2096 = prims.broadcast_in_dim(t2095, [1, 1, 4096], [0, 2]) # t2096: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2095\n", + " t457 = Tensor.expand(t2096, (1, 512, 4096)) # t457: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t457 = ltorch.expand(t2096, (1, 512, 4096)) # t457: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t457 = prims.broadcast_in_dim(t2096, (1, 512, 4096), (0, 1, 2)) # t457: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2096\n", + " t2092 = torch.unsqueeze(t71, 0) # t2092: \"cuda:0 bf16[1, 4096]\"\n", + " # t2092 = ltorch.unsqueeze(t71, 0) # t2092: \"cuda:0 bf16[1, 4096]\"\n", + " # t2092 = prims.broadcast_in_dim(t71, [1, 4096], [1]) # t2092: \"cuda:0 bf16[1, 4096]\"\n", + " t2093 = torch.unsqueeze(t2092, 1) # t2093: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2093 = ltorch.unsqueeze(t2092, 1) # t2093: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2093 = prims.broadcast_in_dim(t2092, [1, 1, 4096], [0, 2]) # t2093: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2092\n", + " t421 = Tensor.expand(t2093, (1, 512, 4096)) # t421: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t421 = ltorch.expand(t2093, (1, 512, 4096)) # t421: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t421 = prims.broadcast_in_dim(t2093, (1, 512, 4096), (0, 1, 2)) # t421: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2093\n", + " t2116 = torch.unsqueeze(t72, 0) # t2116: \"cuda:0 bf16[1, 4096]\"\n", + " # t2116 = ltorch.unsqueeze(t72, 0) # t2116: \"cuda:0 bf16[1, 4096]\"\n", + " # t2116 = prims.broadcast_in_dim(t72, [1, 4096], [1]) # t2116: \"cuda:0 bf16[1, 4096]\"\n", + " t2117 = torch.unsqueeze(t2116, 1) # t2117: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2117 = ltorch.unsqueeze(t2116, 1) # t2117: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2117 = prims.broadcast_in_dim(t2116, [1, 1, 4096], [0, 2]) # t2117: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2116\n", + " t529 = Tensor.expand(t2117, (1, 512, 4096)) # t529: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t529 = ltorch.expand(t2117, (1, 512, 4096)) # t529: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t529 = prims.broadcast_in_dim(t2117, (1, 512, 4096), (0, 1, 2)) # t529: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2117\n", + " t2119 = torch.unsqueeze(t63, 0) # t2119: \"cuda:0 bf16[1, 4096]\"\n", + " # t2119 = ltorch.unsqueeze(t63, 0) # t2119: \"cuda:0 bf16[1, 4096]\"\n", + " # t2119 = prims.broadcast_in_dim(t63, [1, 4096], [1]) # t2119: \"cuda:0 bf16[1, 4096]\"\n", + " t2120 = torch.unsqueeze(t2119, 1) # t2120: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2120 = ltorch.unsqueeze(t2119, 1) # t2120: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2120 = prims.broadcast_in_dim(t2119, [1, 1, 4096], [0, 2]) # t2120: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2119\n", + " t565 = Tensor.expand(t2120, (1, 512, 4096)) # t565: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t565 = ltorch.expand(t2120, (1, 512, 4096)) # t565: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t565 = prims.broadcast_in_dim(t2120, (1, 512, 4096), (0, 1, 2)) # t565: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2120\n", + " t2140 = torch.unsqueeze(t73, 0) # t2140: \"cuda:0 bf16[1, 4096]\"\n", + " # t2140 = ltorch.unsqueeze(t73, 0) # t2140: \"cuda:0 bf16[1, 4096]\"\n", + " # t2140 = prims.broadcast_in_dim(t73, [1, 4096], [1]) # t2140: \"cuda:0 bf16[1, 4096]\"\n", + " t2141 = torch.unsqueeze(t2140, 1) # t2141: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2141 = ltorch.unsqueeze(t2140, 1) # t2141: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2141 = prims.broadcast_in_dim(t2140, [1, 1, 4096], [0, 2]) # t2141: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2140\n", + " t637 = Tensor.expand(t2141, (1, 512, 4096)) # t637: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t637 = ltorch.expand(t2141, (1, 512, 4096)) # t637: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t637 = prims.broadcast_in_dim(t2141, (1, 512, 4096), (0, 1, 2)) # t637: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2141\n", + " t2143 = torch.unsqueeze(t64, 0) # t2143: \"cuda:0 bf16[1, 4096]\"\n", + " # t2143 = ltorch.unsqueeze(t64, 0) # t2143: \"cuda:0 bf16[1, 4096]\"\n", + " # t2143 = prims.broadcast_in_dim(t64, [1, 4096], [1]) # t2143: \"cuda:0 bf16[1, 4096]\"\n", + " t2144 = torch.unsqueeze(t2143, 1) # t2144: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2144 = ltorch.unsqueeze(t2143, 1) # t2144: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2144 = prims.broadcast_in_dim(t2143, [1, 1, 4096], [0, 2]) # t2144: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2143\n", + " t673 = Tensor.expand(t2144, (1, 512, 4096)) # t673: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t673 = ltorch.expand(t2144, (1, 512, 4096)) # t673: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t673 = prims.broadcast_in_dim(t2144, (1, 512, 4096), (0, 1, 2)) # t673: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2144\n", + " t2164 = torch.unsqueeze(t74, 0) # t2164: \"cuda:0 bf16[1, 4096]\"\n", + " # t2164 = ltorch.unsqueeze(t74, 0) # t2164: \"cuda:0 bf16[1, 4096]\"\n", + " # t2164 = prims.broadcast_in_dim(t74, [1, 4096], [1]) # t2164: \"cuda:0 bf16[1, 4096]\"\n", + " t2165 = torch.unsqueeze(t2164, 1) # t2165: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2165 = ltorch.unsqueeze(t2164, 1) # t2165: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2165 = prims.broadcast_in_dim(t2164, [1, 1, 4096], [0, 2]) # t2165: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2164\n", + " t745 = Tensor.expand(t2165, (1, 512, 4096)) # t745: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t745 = ltorch.expand(t2165, (1, 512, 4096)) # t745: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t745 = prims.broadcast_in_dim(t2165, (1, 512, 4096), (0, 1, 2)) # t745: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2165\n", + " t2167 = torch.unsqueeze(t65, 0) # t2167: \"cuda:0 bf16[1, 4096]\"\n", + " # t2167 = ltorch.unsqueeze(t65, 0) # t2167: \"cuda:0 bf16[1, 4096]\"\n", + " # t2167 = prims.broadcast_in_dim(t65, [1, 4096], [1]) # t2167: \"cuda:0 bf16[1, 4096]\"\n", + " t2168 = torch.unsqueeze(t2167, 1) # t2168: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2168 = ltorch.unsqueeze(t2167, 1) # t2168: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2168 = prims.broadcast_in_dim(t2167, [1, 1, 4096], [0, 2]) # t2168: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2167\n", + " t781 = Tensor.expand(t2168, (1, 512, 4096)) # t781: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t781 = ltorch.expand(t2168, (1, 512, 4096)) # t781: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t781 = prims.broadcast_in_dim(t2168, (1, 512, 4096), (0, 1, 2)) # t781: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2168\n", + " t2188 = torch.unsqueeze(t75, 0) # t2188: \"cuda:0 bf16[1, 4096]\"\n", + " # t2188 = ltorch.unsqueeze(t75, 0) # t2188: \"cuda:0 bf16[1, 4096]\"\n", + " # t2188 = prims.broadcast_in_dim(t75, [1, 4096], [1]) # t2188: \"cuda:0 bf16[1, 4096]\"\n", + " t2189 = torch.unsqueeze(t2188, 1) # t2189: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2189 = ltorch.unsqueeze(t2188, 1) # t2189: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2189 = prims.broadcast_in_dim(t2188, [1, 1, 4096], [0, 2]) # t2189: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2188\n", + " t853 = Tensor.expand(t2189, (1, 512, 4096)) # t853: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t853 = ltorch.expand(t2189, (1, 512, 4096)) # t853: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t853 = prims.broadcast_in_dim(t2189, (1, 512, 4096), (0, 1, 2)) # t853: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2189\n", + " t2191 = torch.unsqueeze(t66, 0) # t2191: \"cuda:0 bf16[1, 4096]\"\n", + " # t2191 = ltorch.unsqueeze(t66, 0) # t2191: \"cuda:0 bf16[1, 4096]\"\n", + " # t2191 = prims.broadcast_in_dim(t66, [1, 4096], [1]) # t2191: \"cuda:0 bf16[1, 4096]\"\n", + " t2192 = torch.unsqueeze(t2191, 1) # t2192: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2192 = ltorch.unsqueeze(t2191, 1) # t2192: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2192 = prims.broadcast_in_dim(t2191, [1, 1, 4096], [0, 2]) # t2192: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2191\n", + " t889 = Tensor.expand(t2192, (1, 512, 4096)) # t889: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t889 = ltorch.expand(t2192, (1, 512, 4096)) # t889: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t889 = prims.broadcast_in_dim(t2192, (1, 512, 4096), (0, 1, 2)) # t889: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2192\n", + " t2212 = torch.unsqueeze(t76, 0) # t2212: \"cuda:0 bf16[1, 4096]\"\n", + " # t2212 = ltorch.unsqueeze(t76, 0) # t2212: \"cuda:0 bf16[1, 4096]\"\n", + " # t2212 = prims.broadcast_in_dim(t76, [1, 4096], [1]) # t2212: \"cuda:0 bf16[1, 4096]\"\n", + " t2213 = torch.unsqueeze(t2212, 1) # t2213: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2213 = ltorch.unsqueeze(t2212, 1) # t2213: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2213 = prims.broadcast_in_dim(t2212, [1, 1, 4096], [0, 2]) # t2213: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2212\n", + " t961 = Tensor.expand(t2213, (1, 512, 4096)) # t961: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t961 = ltorch.expand(t2213, (1, 512, 4096)) # t961: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t961 = prims.broadcast_in_dim(t2213, (1, 512, 4096), (0, 1, 2)) # t961: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2213\n", + " t2215 = torch.unsqueeze(t67, 0) # t2215: \"cuda:0 bf16[1, 4096]\"\n", + " # t2215 = ltorch.unsqueeze(t67, 0) # t2215: \"cuda:0 bf16[1, 4096]\"\n", + " # t2215 = prims.broadcast_in_dim(t67, [1, 4096], [1]) # t2215: \"cuda:0 bf16[1, 4096]\"\n", + " t2216 = torch.unsqueeze(t2215, 1) # t2216: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2216 = ltorch.unsqueeze(t2215, 1) # t2216: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2216 = prims.broadcast_in_dim(t2215, [1, 1, 4096], [0, 2]) # t2216: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2215\n", + " t997 = Tensor.expand(t2216, (1, 512, 4096)) # t997: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t997 = ltorch.expand(t2216, (1, 512, 4096)) # t997: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t997 = prims.broadcast_in_dim(t2216, (1, 512, 4096), (0, 1, 2)) # t997: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2216\n", + " t2236 = torch.unsqueeze(t77, 0) # t2236: \"cuda:0 bf16[1, 4096]\"\n", + " # t2236 = ltorch.unsqueeze(t77, 0) # t2236: \"cuda:0 bf16[1, 4096]\"\n", + " # t2236 = prims.broadcast_in_dim(t77, [1, 4096], [1]) # t2236: \"cuda:0 bf16[1, 4096]\"\n", + " t2237 = torch.unsqueeze(t2236, 1) # t2237: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2237 = ltorch.unsqueeze(t2236, 1) # t2237: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2237 = prims.broadcast_in_dim(t2236, [1, 1, 4096], [0, 2]) # t2237: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2236\n", + " t1069 = Tensor.expand(t2237, (1, 512, 4096)) # t1069: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1069 = ltorch.expand(t2237, (1, 512, 4096)) # t1069: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1069 = prims.broadcast_in_dim(t2237, (1, 512, 4096), (0, 1, 2)) # t1069: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2237\n", + " t2239 = torch.unsqueeze(t68, 0) # t2239: \"cuda:0 bf16[1, 4096]\"\n", + " # t2239 = ltorch.unsqueeze(t68, 0) # t2239: \"cuda:0 bf16[1, 4096]\"\n", + " # t2239 = prims.broadcast_in_dim(t68, [1, 4096], [1]) # t2239: \"cuda:0 bf16[1, 4096]\"\n", + " t2240 = torch.unsqueeze(t2239, 1) # t2240: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2240 = ltorch.unsqueeze(t2239, 1) # t2240: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2240 = prims.broadcast_in_dim(t2239, [1, 1, 4096], [0, 2]) # t2240: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2239\n", + " t1105 = Tensor.expand(t2240, (1, 512, 4096)) # t1105: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1105 = ltorch.expand(t2240, (1, 512, 4096)) # t1105: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1105 = prims.broadcast_in_dim(t2240, (1, 512, 4096), (0, 1, 2)) # t1105: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2240\n", + " t2260 = torch.unsqueeze(t78, 0) # t2260: \"cuda:0 bf16[1, 4096]\"\n", + " # t2260 = ltorch.unsqueeze(t78, 0) # t2260: \"cuda:0 bf16[1, 4096]\"\n", + " # t2260 = prims.broadcast_in_dim(t78, [1, 4096], [1]) # t2260: \"cuda:0 bf16[1, 4096]\"\n", + " t2261 = torch.unsqueeze(t2260, 1) # t2261: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2261 = ltorch.unsqueeze(t2260, 1) # t2261: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2261 = prims.broadcast_in_dim(t2260, [1, 1, 4096], [0, 2]) # t2261: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2260\n", + " t1177 = Tensor.expand(t2261, (1, 512, 4096)) # t1177: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1177 = ltorch.expand(t2261, (1, 512, 4096)) # t1177: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1177 = prims.broadcast_in_dim(t2261, (1, 512, 4096), (0, 1, 2)) # t1177: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2261\n", + " t2263 = torch.unsqueeze(t54, 0) # t2263: \"cuda:0 bf16[1, 4096]\"\n", + " # t2263 = ltorch.unsqueeze(t54, 0) # t2263: \"cuda:0 bf16[1, 4096]\"\n", + " # t2263 = prims.broadcast_in_dim(t54, [1, 4096], [1]) # t2263: \"cuda:0 bf16[1, 4096]\"\n", + " t2264 = torch.unsqueeze(t2263, 1) # t2264: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2264 = ltorch.unsqueeze(t2263, 1) # t2264: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2264 = prims.broadcast_in_dim(t2263, [1, 1, 4096], [0, 2]) # t2264: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2263\n", + " t1213 = Tensor.expand(t2264, (1, 512, 4096)) # t1213: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1213 = ltorch.expand(t2264, (1, 512, 4096)) # t1213: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1213 = prims.broadcast_in_dim(t2264, (1, 512, 4096), (0, 1, 2)) # t1213: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2264\n", + " t2284 = torch.unsqueeze(t79, 0) # t2284: \"cuda:0 bf16[1, 4096]\"\n", + " # t2284 = ltorch.unsqueeze(t79, 0) # t2284: \"cuda:0 bf16[1, 4096]\"\n", + " # t2284 = prims.broadcast_in_dim(t79, [1, 4096], [1]) # t2284: \"cuda:0 bf16[1, 4096]\"\n", + " t2285 = torch.unsqueeze(t2284, 1) # t2285: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2285 = ltorch.unsqueeze(t2284, 1) # t2285: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2285 = prims.broadcast_in_dim(t2284, [1, 1, 4096], [0, 2]) # t2285: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2284\n", + " t1285 = Tensor.expand(t2285, (1, 512, 4096)) # t1285: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1285 = ltorch.expand(t2285, (1, 512, 4096)) # t1285: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1285 = prims.broadcast_in_dim(t2285, (1, 512, 4096), (0, 1, 2)) # t1285: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2285\n", + " t2287 = torch.unsqueeze(t55, 0) # t2287: \"cuda:0 bf16[1, 4096]\"\n", + " # t2287 = ltorch.unsqueeze(t55, 0) # t2287: \"cuda:0 bf16[1, 4096]\"\n", + " # t2287 = prims.broadcast_in_dim(t55, [1, 4096], [1]) # t2287: \"cuda:0 bf16[1, 4096]\"\n", + " t2288 = torch.unsqueeze(t2287, 1) # t2288: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2288 = ltorch.unsqueeze(t2287, 1) # t2288: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2288 = prims.broadcast_in_dim(t2287, [1, 1, 4096], [0, 2]) # t2288: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2287\n", + " t1321 = Tensor.expand(t2288, (1, 512, 4096)) # t1321: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1321 = ltorch.expand(t2288, (1, 512, 4096)) # t1321: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1321 = prims.broadcast_in_dim(t2288, (1, 512, 4096), (0, 1, 2)) # t1321: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2288\n", + " t2308 = torch.unsqueeze(t80, 0) # t2308: \"cuda:0 bf16[1, 4096]\"\n", + " # t2308 = ltorch.unsqueeze(t80, 0) # t2308: \"cuda:0 bf16[1, 4096]\"\n", + " # t2308 = prims.broadcast_in_dim(t80, [1, 4096], [1]) # t2308: \"cuda:0 bf16[1, 4096]\"\n", + " t2309 = torch.unsqueeze(t2308, 1) # t2309: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2309 = ltorch.unsqueeze(t2308, 1) # t2309: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2309 = prims.broadcast_in_dim(t2308, [1, 1, 4096], [0, 2]) # t2309: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2308\n", + " t1393 = Tensor.expand(t2309, (1, 512, 4096)) # t1393: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1393 = ltorch.expand(t2309, (1, 512, 4096)) # t1393: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1393 = prims.broadcast_in_dim(t2309, (1, 512, 4096), (0, 1, 2)) # t1393: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2309\n", + " t2311 = torch.unsqueeze(t56, 0) # t2311: \"cuda:0 bf16[1, 4096]\"\n", + " # t2311 = ltorch.unsqueeze(t56, 0) # t2311: \"cuda:0 bf16[1, 4096]\"\n", + " # t2311 = prims.broadcast_in_dim(t56, [1, 4096], [1]) # t2311: \"cuda:0 bf16[1, 4096]\"\n", + " t2312 = torch.unsqueeze(t2311, 1) # t2312: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2312 = ltorch.unsqueeze(t2311, 1) # t2312: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2312 = prims.broadcast_in_dim(t2311, [1, 1, 4096], [0, 2]) # t2312: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2311\n", + " t1429 = Tensor.expand(t2312, (1, 512, 4096)) # t1429: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1429 = ltorch.expand(t2312, (1, 512, 4096)) # t1429: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1429 = prims.broadcast_in_dim(t2312, (1, 512, 4096), (0, 1, 2)) # t1429: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2312\n", + " t2332 = torch.unsqueeze(t81, 0) # t2332: \"cuda:0 bf16[1, 4096]\"\n", + " # t2332 = ltorch.unsqueeze(t81, 0) # t2332: \"cuda:0 bf16[1, 4096]\"\n", + " # t2332 = prims.broadcast_in_dim(t81, [1, 4096], [1]) # t2332: \"cuda:0 bf16[1, 4096]\"\n", + " t2333 = torch.unsqueeze(t2332, 1) # t2333: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2333 = ltorch.unsqueeze(t2332, 1) # t2333: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2333 = prims.broadcast_in_dim(t2332, [1, 1, 4096], [0, 2]) # t2333: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2332\n", + " t1501 = Tensor.expand(t2333, (1, 512, 4096)) # t1501: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1501 = ltorch.expand(t2333, (1, 512, 4096)) # t1501: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1501 = prims.broadcast_in_dim(t2333, (1, 512, 4096), (0, 1, 2)) # t1501: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2333\n", + " t2335 = torch.unsqueeze(t57, 0) # t2335: \"cuda:0 bf16[1, 4096]\"\n", + " # t2335 = ltorch.unsqueeze(t57, 0) # t2335: \"cuda:0 bf16[1, 4096]\"\n", + " # t2335 = prims.broadcast_in_dim(t57, [1, 4096], [1]) # t2335: \"cuda:0 bf16[1, 4096]\"\n", + " t2336 = torch.unsqueeze(t2335, 1) # t2336: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2336 = ltorch.unsqueeze(t2335, 1) # t2336: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2336 = prims.broadcast_in_dim(t2335, [1, 1, 4096], [0, 2]) # t2336: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2335\n", + " t1537 = Tensor.expand(t2336, (1, 512, 4096)) # t1537: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1537 = ltorch.expand(t2336, (1, 512, 4096)) # t1537: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1537 = prims.broadcast_in_dim(t2336, (1, 512, 4096), (0, 1, 2)) # t1537: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2336\n", + " t2036 = torch.unsqueeze(t118, 0) # t2036: \"cuda:0 f32[1, 512, 128]\"\n", + " # t2036 = ltorch.unsqueeze(t118, 0) # t2036: \"cuda:0 f32[1, 512, 128]\"\n", + " # t2036 = prims.broadcast_in_dim(t118, [1, 512, 128], [1, 2]) # t2036: \"cuda:0 f32[1, 512, 128]\"\n", + " del t118\n", + " t2037 = torch.unsqueeze(t2036, 1) # t2037: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " # t2037 = ltorch.unsqueeze(t2036, 1) # t2037: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " # t2037 = prims.broadcast_in_dim(t2036, [1, 1, 512, 128], [0, 2, 3]) # t2037: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " del t2036\n", + " t154 = Tensor.expand(t2037, (1, 32, 512, 128)) # t154: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t154 = ltorch.expand(t2037, (1, 32, 512, 128)) # t154: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t154 = prims.broadcast_in_dim(t2037, (1, 32, 512, 128), (0, 1, 2, 3)) # t154: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t2037\n", + " t2039 = torch.unsqueeze(t119, 0) # t2039: \"cuda:0 f32[1, 512, 128]\"\n", + " # t2039 = ltorch.unsqueeze(t119, 0) # t2039: \"cuda:0 f32[1, 512, 128]\"\n", + " # t2039 = prims.broadcast_in_dim(t119, [1, 512, 128], [1, 2]) # t2039: \"cuda:0 f32[1, 512, 128]\"\n", + " del t119\n", + " t2040 = torch.unsqueeze(t2039, 1) # t2040: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " # t2040 = ltorch.unsqueeze(t2039, 1) # t2040: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " # t2040 = prims.broadcast_in_dim(t2039, [1, 1, 512, 128], [0, 2, 3]) # t2040: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " del t2039\n", + " t157 = Tensor.expand(t2040, (1, 32, 512, 128)) # t157: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t157 = ltorch.expand(t2040, (1, 32, 512, 128)) # t157: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t157 = prims.broadcast_in_dim(t2040, (1, 32, 512, 128), (0, 1, 2, 3)) # t157: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t2040\n", + " [t129, t137] = nvFusion0(t122, t133)\n", + " # t123 = prims.convert_element_type(t122, dtypes.float32) # t123: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t124 = prims.mul(t123, t123) # t124: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t125 = prims.sum(t124, (2,)) # t125: \"cuda:0 f32[1, 512]\"\n", + " # t126 = prims.broadcast_in_dim(t125, [1, 512, 1], [0, 1]) # t126: \"cuda:0 f32[1, 512, 1]\"\n", + " # t127 = prims.div(t126, 4096.0) # t127: \"cuda:0 f32[1, 512, 1]\"\n", + " # t128 = prims.add(t127, 1e-05) # t128: \"cuda:0 f32[1, 512, 1]\"\n", + " # t129 = prims.rsqrt(t128) # t129: \"cuda:0 f32[1, 512, 1]\"\n", + " # t130 = prims.broadcast_in_dim(t129, (1, 512, 4096), (0, 1, 2)) # t130: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t131 = prims.mul(t123, t130) # t131: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t135 = prims.convert_element_type(t133, dtypes.float32) # t135: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t136 = prims.mul(t131, t135) # t136: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t137 = prims.convert_element_type(t136, dtypes.bfloat16) # t137: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t138 = torch.nn.functional.linear(t137, t3, None) # t138: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t138 = ltorch.linear(t137, t3, None) # t138: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t138 = prims.linear(t137, t3, None) # t138: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t139 = torch.reshape(t138, (1, 512, 32, 3, 128)) # t139: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t139 = ltorch.reshape(t138, (1, 512, 32, 3, 128)) # t139: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t139 = prims.reshape(t138, (1, 512, 32, 3, 128)) # t139: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t138\n", + " t140 = torch.permute(t139, (0, 2, 3, 1, 4)) # t140: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t140 = ltorch.permute(t139, (0, 2, 3, 1, 4)) # t140: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t140 = prims.transpose(t139, (0, 2, 3, 1, 4)) # t140: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t139\n", + " (t141, t142, t143) = torch.split(t140, (1, 1, 1), 2)\n", + " # (t141, t142, t143) = ltorch.split(t140, (1, 1, 1), 2)\n", + " # t141 = prims.slice_prim(t140, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t141: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t142 = prims.slice_prim(t140, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t142: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t143 = prims.slice_prim(t140, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t143: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t140\n", + " t144 = torch.reshape(t141, (1, 32, 512, 128)) # t144: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t144 = ltorch.reshape(t141, (1, 32, 512, 128)) # t144: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t144 = prims.reshape(t141, (1, 32, 512, 128)) # t144: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t141\n", + " t145 = torch.reshape(t142, (1, 32, 512, 128)) # t145: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t145 = ltorch.reshape(t142, (1, 32, 512, 128)) # t145: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t145 = prims.reshape(t142, (1, 32, 512, 128)) # t145: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t142\n", + " t146 = torch.reshape(t143, (1, 32, 512, 128)) # t146: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t146 = ltorch.reshape(t143, (1, 32, 512, 128)) # t146: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t146 = prims.reshape(t143, (1, 32, 512, 128)) # t146: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t143\n", + " t147 = torch_slice_prim_impl(t144, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t147: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t162 = torch_slice_prim_impl(t145, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t162: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t177 = torch_slice_prim_impl(t144, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t177: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t144\n", + " t179 = torch_slice_prim_impl(t145, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t179: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t145\n", + " t149 = torch_slice_prim_impl(t147, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t149: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t148 = torch_slice_prim_impl(t147, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t148: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t163 = torch_slice_prim_impl(t162, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t163: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t164 = torch_slice_prim_impl(t162, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t164: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t152, t167] = nvFusion1(t147, t149, t162, t164)\n", + " # t150 = prims.convert_element_type(t149, dtypes.float32) # t150: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t151 = prims.neg(t150) # t151: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t152 = prims.convert_element_type(t151, dtypes.bfloat16) # t152: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t165 = prims.convert_element_type(t164, dtypes.float32) # t165: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t166 = prims.neg(t165) # t166: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t167 = prims.convert_element_type(t166, dtypes.bfloat16) # t167: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t149, t164\n", + " t168 = torch.cat((t167, t163), -1) # t168: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t168 = ltorch.cat((t167, t163), -1) # t168: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t168 = prims.cat((t167, t163), -1) # t168: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t167, t163\n", + " t153 = torch.cat((t152, t148), -1) # t153: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t153 = ltorch.cat((t152, t148), -1) # t153: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t153 = prims.cat((t152, t148), -1) # t153: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t152, t148\n", + " [t161, t176] = nvFusion2(t147, t153, t154, t157, t162, t168)\n", + " # t155 = prims.convert_element_type(t147, dtypes.float32) # t155: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t170 = prims.convert_element_type(t162, dtypes.float32) # t170: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t156 = prims.mul(t155, t154) # t156: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t158 = prims.convert_element_type(t153, dtypes.float32) # t158: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t159 = prims.mul(t158, t157) # t159: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t160 = prims.add(t156, t159) # t160: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t161 = prims.convert_element_type(t160, dtypes.bfloat16) # t161: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t171 = prims.mul(t170, t154) # t171: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t173 = prims.convert_element_type(t168, dtypes.float32) # t173: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t174 = prims.mul(t173, t157) # t174: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t175 = prims.add(t171, t174) # t175: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t176 = prims.convert_element_type(t175, dtypes.bfloat16) # t176: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t147, t153, t162, t168\n", + " t178 = torch.cat((t161, t177), -1) # t178: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t178 = ltorch.cat((t161, t177), -1) # t178: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t178 = prims.cat((t161, t177), -1) # t178: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t161, t177\n", + " t180 = torch.cat((t176, t179), -1) # t180: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t180 = ltorch.cat((t176, t179), -1) # t180: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t180 = prims.cat((t176, t179), -1) # t180: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t176, t179\n", + " (t181, t182, t183, t184, _, _, t185, t186, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t178, t180, t146, 0.0, True, scale=0.08838834764831843)\n", + " t188 = torch.permute(t181, (0, 2, 1, 3)) # t188: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t188 = ltorch.permute(t181, (0, 2, 1, 3)) # t188: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t188 = prims.transpose(t181, (0, 2, 1, 3)) # t188: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t189 = torch.reshape(t188, (1, 512, 4096)) # t189: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t189 = ltorch.reshape(t188, (1, 512, 4096)) # t189: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t189 = prims.reshape(t188, (1, 512, 4096)) # t189: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t188\n", + " t190 = torch.nn.functional.linear(t189, t85, None) # t190: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t190 = ltorch.linear(t189, t85, None) # t190: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t190 = prims.linear(t189, t85, None) # t190: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t194, t201, t209] = nvFusion3(t122, t190, t205)\n", + " # t191 = prims.convert_element_type(t190, dtypes.float32) # t191: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t192 = prims.convert_element_type(t122, dtypes.float32) # t192: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t193 = prims.add(t191, t192) # t193: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t194 = prims.convert_element_type(t193, dtypes.bfloat16) # t194: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t196 = prims.mul(t193, t193) # t196: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t197 = prims.sum(t196, (2,)) # t197: \"cuda:0 f32[1, 512]\"\n", + " # t198 = prims.broadcast_in_dim(t197, [1, 512, 1], [0, 1]) # t198: \"cuda:0 f32[1, 512, 1]\"\n", + " # t199 = prims.div(t198, 4096.0) # t199: \"cuda:0 f32[1, 512, 1]\"\n", + " # t200 = prims.add(t199, 1e-05) # t200: \"cuda:0 f32[1, 512, 1]\"\n", + " # t201 = prims.rsqrt(t200) # t201: \"cuda:0 f32[1, 512, 1]\"\n", + " # t202 = prims.broadcast_in_dim(t201, (1, 512, 4096), (0, 1, 2)) # t202: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t203 = prims.mul(t193, t202) # t203: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t207 = prims.convert_element_type(t205, dtypes.float32) # t207: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t208 = prims.mul(t203, t207) # t208: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t209 = prims.convert_element_type(t208, dtypes.bfloat16) # t209: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t210 = torch.nn.functional.linear(t209, t19, None) # t210: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t210 = ltorch.linear(t209, t19, None) # t210: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t210 = prims.linear(t209, t19, None) # t210: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t211 = torch.nn.functional.linear(t209, t35, None) # t211: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t211 = ltorch.linear(t209, t35, None) # t211: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t211 = prims.linear(t209, t35, None) # t211: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t225] = nvFusion4(t210, t211)\n", + " # t212 = prims.convert_element_type(t210, dtypes.float32) # t212: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t213 = prims.neg(t212) # t213: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t214 = prims.exp(t213) # t214: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t215 = prims.add(1.0, t214) # t215: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t216 = prims.reciprocal(t215) # t216: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t220 = prims.mul(t212, t216) # t220: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t223 = prims.convert_element_type(t211, dtypes.float32) # t223: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t224 = prims.mul(t220, t223) # t224: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t225 = prims.convert_element_type(t224, dtypes.bfloat16) # t225: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t226 = torch.nn.functional.linear(t225, t86, None) # t226: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t226 = ltorch.linear(t225, t86, None) # t226: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t226 = prims.linear(t225, t86, None) # t226: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t230, t237, t245] = nvFusion5(t194, t226, t241)\n", + " # t228 = prims.convert_element_type(t194, dtypes.float32) # t228: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t227 = prims.convert_element_type(t226, dtypes.float32) # t227: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t229 = prims.add(t227, t228) # t229: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t230 = prims.convert_element_type(t229, dtypes.bfloat16) # t230: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t232 = prims.mul(t229, t229) # t232: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t233 = prims.sum(t232, (2,)) # t233: \"cuda:0 f32[1, 512]\"\n", + " # t234 = prims.broadcast_in_dim(t233, [1, 512, 1], [0, 1]) # t234: \"cuda:0 f32[1, 512, 1]\"\n", + " # t235 = prims.div(t234, 4096.0) # t235: \"cuda:0 f32[1, 512, 1]\"\n", + " # t236 = prims.add(t235, 1e-05) # t236: \"cuda:0 f32[1, 512, 1]\"\n", + " # t237 = prims.rsqrt(t236) # t237: \"cuda:0 f32[1, 512, 1]\"\n", + " # t238 = prims.broadcast_in_dim(t237, (1, 512, 4096), (0, 1, 2)) # t238: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t239 = prims.mul(t229, t238) # t239: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t243 = prims.convert_element_type(t241, dtypes.float32) # t243: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t244 = prims.mul(t239, t243) # t244: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t245 = prims.convert_element_type(t244, dtypes.bfloat16) # t245: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t246 = torch.nn.functional.linear(t245, t4, None) # t246: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t246 = ltorch.linear(t245, t4, None) # t246: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t246 = prims.linear(t245, t4, None) # t246: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t247 = torch.reshape(t246, (1, 512, 32, 3, 128)) # t247: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t247 = ltorch.reshape(t246, (1, 512, 32, 3, 128)) # t247: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t247 = prims.reshape(t246, (1, 512, 32, 3, 128)) # t247: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t246\n", + " t248 = torch.permute(t247, (0, 2, 3, 1, 4)) # t248: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t248 = ltorch.permute(t247, (0, 2, 3, 1, 4)) # t248: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t248 = prims.transpose(t247, (0, 2, 3, 1, 4)) # t248: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t247\n", + " (t249, t250, t251) = torch.split(t248, (1, 1, 1), 2)\n", + " # (t249, t250, t251) = ltorch.split(t248, (1, 1, 1), 2)\n", + " # t249 = prims.slice_prim(t248, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t249: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t250 = prims.slice_prim(t248, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t250: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t251 = prims.slice_prim(t248, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t251: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t248\n", + " t252 = torch.reshape(t249, (1, 32, 512, 128)) # t252: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t252 = ltorch.reshape(t249, (1, 32, 512, 128)) # t252: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t252 = prims.reshape(t249, (1, 32, 512, 128)) # t252: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t249\n", + " t253 = torch.reshape(t250, (1, 32, 512, 128)) # t253: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t253 = ltorch.reshape(t250, (1, 32, 512, 128)) # t253: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t253 = prims.reshape(t250, (1, 32, 512, 128)) # t253: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t250\n", + " t254 = torch.reshape(t251, (1, 32, 512, 128)) # t254: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t254 = ltorch.reshape(t251, (1, 32, 512, 128)) # t254: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t254 = prims.reshape(t251, (1, 32, 512, 128)) # t254: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t251\n", + " t285 = torch_slice_prim_impl(t252, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t285: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " t287 = torch_slice_prim_impl(t253, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t287: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " t255 = torch_slice_prim_impl(t252, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t255: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t252\n", + " t270 = torch_slice_prim_impl(t253, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t270: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t253\n", + " t256 = torch_slice_prim_impl(t255, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t256: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t257 = torch_slice_prim_impl(t255, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t257: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t272 = torch_slice_prim_impl(t270, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t272: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t271 = torch_slice_prim_impl(t270, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t271: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t260, t275] = nvFusion6(t255, t257, t270, t272)\n", + " # t258 = prims.convert_element_type(t257, dtypes.float32) # t258: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t259 = prims.neg(t258) # t259: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t260 = prims.convert_element_type(t259, dtypes.bfloat16) # t260: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t273 = prims.convert_element_type(t272, dtypes.float32) # t273: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t274 = prims.neg(t273) # t274: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t275 = prims.convert_element_type(t274, dtypes.bfloat16) # t275: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t257, t272\n", + " t261 = torch.cat((t260, t256), -1) # t261: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t261 = ltorch.cat((t260, t256), -1) # t261: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t261 = prims.cat((t260, t256), -1) # t261: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t260, t256\n", + " t276 = torch.cat((t275, t271), -1) # t276: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t276 = ltorch.cat((t275, t271), -1) # t276: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t276 = prims.cat((t275, t271), -1) # t276: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t275, t271\n", + " [t269, t284] = nvFusion7(t154, t157, t255, t261, t270, t276)\n", + " # t263 = prims.convert_element_type(t255, dtypes.float32) # t263: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t278 = prims.convert_element_type(t270, dtypes.float32) # t278: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t264 = prims.mul(t263, t154) # t264: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t266 = prims.convert_element_type(t261, dtypes.float32) # t266: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t267 = prims.mul(t266, t157) # t267: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t268 = prims.add(t264, t267) # t268: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t269 = prims.convert_element_type(t268, dtypes.bfloat16) # t269: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t279 = prims.mul(t278, t154) # t279: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t281 = prims.convert_element_type(t276, dtypes.float32) # t281: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t282 = prims.mul(t281, t157) # t282: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t283 = prims.add(t279, t282) # t283: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t284 = prims.convert_element_type(t283, dtypes.bfloat16) # t284: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t255, t261, t270, t276\n", + " t288 = torch.cat((t284, t287), -1) # t288: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t288 = ltorch.cat((t284, t287), -1) # t288: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t288 = prims.cat((t284, t287), -1) # t288: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t284, t287\n", + " t286 = torch.cat((t269, t285), -1) # t286: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t286 = ltorch.cat((t269, t285), -1) # t286: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t286 = prims.cat((t269, t285), -1) # t286: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t269, t285\n", + " (t289, t290, t291, t292, _, _, t293, t294, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t286, t288, t254, 0.0, True, scale=0.08838834764831843)\n", + " t296 = torch.permute(t289, (0, 2, 1, 3)) # t296: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t296 = ltorch.permute(t289, (0, 2, 1, 3)) # t296: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t296 = prims.transpose(t289, (0, 2, 1, 3)) # t296: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t297 = torch.reshape(t296, (1, 512, 4096)) # t297: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t297 = ltorch.reshape(t296, (1, 512, 4096)) # t297: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t297 = prims.reshape(t296, (1, 512, 4096)) # t297: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t296\n", + " t298 = torch.nn.functional.linear(t297, t87, None) # t298: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t298 = ltorch.linear(t297, t87, None) # t298: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t298 = prims.linear(t297, t87, None) # t298: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t302, t309, t317] = nvFusion8(t230, t298, t313)\n", + " # t300 = prims.convert_element_type(t230, dtypes.float32) # t300: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t299 = prims.convert_element_type(t298, dtypes.float32) # t299: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t301 = prims.add(t299, t300) # t301: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t302 = prims.convert_element_type(t301, dtypes.bfloat16) # t302: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t304 = prims.mul(t301, t301) # t304: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t305 = prims.sum(t304, (2,)) # t305: \"cuda:0 f32[1, 512]\"\n", + " # t306 = prims.broadcast_in_dim(t305, [1, 512, 1], [0, 1]) # t306: \"cuda:0 f32[1, 512, 1]\"\n", + " # t307 = prims.div(t306, 4096.0) # t307: \"cuda:0 f32[1, 512, 1]\"\n", + " # t308 = prims.add(t307, 1e-05) # t308: \"cuda:0 f32[1, 512, 1]\"\n", + " # t309 = prims.rsqrt(t308) # t309: \"cuda:0 f32[1, 512, 1]\"\n", + " # t310 = prims.broadcast_in_dim(t309, (1, 512, 4096), (0, 1, 2)) # t310: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t311 = prims.mul(t301, t310) # t311: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t315 = prims.convert_element_type(t313, dtypes.float32) # t315: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t316 = prims.mul(t311, t315) # t316: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t317 = prims.convert_element_type(t316, dtypes.bfloat16) # t317: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t318 = torch.nn.functional.linear(t317, t20, None) # t318: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t318 = ltorch.linear(t317, t20, None) # t318: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t318 = prims.linear(t317, t20, None) # t318: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t319 = torch.nn.functional.linear(t317, t36, None) # t319: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t319 = ltorch.linear(t317, t36, None) # t319: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t319 = prims.linear(t317, t36, None) # t319: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t333] = nvFusion9(t318, t319)\n", + " # t320 = prims.convert_element_type(t318, dtypes.float32) # t320: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t321 = prims.neg(t320) # t321: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t322 = prims.exp(t321) # t322: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t323 = prims.add(1.0, t322) # t323: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t324 = prims.reciprocal(t323) # t324: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t328 = prims.mul(t320, t324) # t328: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t331 = prims.convert_element_type(t319, dtypes.float32) # t331: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t332 = prims.mul(t328, t331) # t332: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t333 = prims.convert_element_type(t332, dtypes.bfloat16) # t333: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t334 = torch.nn.functional.linear(t333, t88, None) # t334: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t334 = ltorch.linear(t333, t88, None) # t334: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t334 = prims.linear(t333, t88, None) # t334: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t338, t345, t353] = nvFusion10(t302, t334, t349)\n", + " # t336 = prims.convert_element_type(t302, dtypes.float32) # t336: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t335 = prims.convert_element_type(t334, dtypes.float32) # t335: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t337 = prims.add(t335, t336) # t337: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t338 = prims.convert_element_type(t337, dtypes.bfloat16) # t338: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t340 = prims.mul(t337, t337) # t340: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t341 = prims.sum(t340, (2,)) # t341: \"cuda:0 f32[1, 512]\"\n", + " # t342 = prims.broadcast_in_dim(t341, [1, 512, 1], [0, 1]) # t342: \"cuda:0 f32[1, 512, 1]\"\n", + " # t343 = prims.div(t342, 4096.0) # t343: \"cuda:0 f32[1, 512, 1]\"\n", + " # t344 = prims.add(t343, 1e-05) # t344: \"cuda:0 f32[1, 512, 1]\"\n", + " # t345 = prims.rsqrt(t344) # t345: \"cuda:0 f32[1, 512, 1]\"\n", + " # t346 = prims.broadcast_in_dim(t345, (1, 512, 4096), (0, 1, 2)) # t346: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t347 = prims.mul(t337, t346) # t347: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t351 = prims.convert_element_type(t349, dtypes.float32) # t351: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t352 = prims.mul(t347, t351) # t352: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t353 = prims.convert_element_type(t352, dtypes.bfloat16) # t353: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t354 = torch.nn.functional.linear(t353, t5, None) # t354: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t354 = ltorch.linear(t353, t5, None) # t354: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t354 = prims.linear(t353, t5, None) # t354: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t355 = torch.reshape(t354, (1, 512, 32, 3, 128)) # t355: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t355 = ltorch.reshape(t354, (1, 512, 32, 3, 128)) # t355: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t355 = prims.reshape(t354, (1, 512, 32, 3, 128)) # t355: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t354\n", + " t356 = torch.permute(t355, (0, 2, 3, 1, 4)) # t356: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t356 = ltorch.permute(t355, (0, 2, 3, 1, 4)) # t356: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t356 = prims.transpose(t355, (0, 2, 3, 1, 4)) # t356: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t355\n", + " (t357, t358, t359) = torch.split(t356, (1, 1, 1), 2)\n", + " # (t357, t358, t359) = ltorch.split(t356, (1, 1, 1), 2)\n", + " # t357 = prims.slice_prim(t356, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t357: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t358 = prims.slice_prim(t356, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t358: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t359 = prims.slice_prim(t356, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t359: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t356\n", + " t360 = torch.reshape(t357, (1, 32, 512, 128)) # t360: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t360 = ltorch.reshape(t357, (1, 32, 512, 128)) # t360: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t360 = prims.reshape(t357, (1, 32, 512, 128)) # t360: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t357\n", + " t361 = torch.reshape(t358, (1, 32, 512, 128)) # t361: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t361 = ltorch.reshape(t358, (1, 32, 512, 128)) # t361: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t361 = prims.reshape(t358, (1, 32, 512, 128)) # t361: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t358\n", + " t362 = torch.reshape(t359, (1, 32, 512, 128)) # t362: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t362 = ltorch.reshape(t359, (1, 32, 512, 128)) # t362: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t362 = prims.reshape(t359, (1, 32, 512, 128)) # t362: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t359\n", + " t363 = torch_slice_prim_impl(t360, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t363: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t378 = torch_slice_prim_impl(t361, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t378: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t393 = torch_slice_prim_impl(t360, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t393: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t360\n", + " t395 = torch_slice_prim_impl(t361, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t395: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t361\n", + " t364 = torch_slice_prim_impl(t363, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t364: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t365 = torch_slice_prim_impl(t363, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t365: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t379 = torch_slice_prim_impl(t378, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t379: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t380 = torch_slice_prim_impl(t378, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t380: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t368, t383] = nvFusion11(t363, t365, t378, t380)\n", + " # t366 = prims.convert_element_type(t365, dtypes.float32) # t366: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t367 = prims.neg(t366) # t367: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t368 = prims.convert_element_type(t367, dtypes.bfloat16) # t368: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t381 = prims.convert_element_type(t380, dtypes.float32) # t381: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t382 = prims.neg(t381) # t382: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t383 = prims.convert_element_type(t382, dtypes.bfloat16) # t383: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t365, t380\n", + " t369 = torch.cat((t368, t364), -1) # t369: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t369 = ltorch.cat((t368, t364), -1) # t369: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t369 = prims.cat((t368, t364), -1) # t369: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t368, t364\n", + " t384 = torch.cat((t383, t379), -1) # t384: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t384 = ltorch.cat((t383, t379), -1) # t384: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t384 = prims.cat((t383, t379), -1) # t384: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t383, t379\n", + " [t377, t392] = nvFusion12(t154, t157, t363, t369, t378, t384)\n", + " # t371 = prims.convert_element_type(t363, dtypes.float32) # t371: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t386 = prims.convert_element_type(t378, dtypes.float32) # t386: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t372 = prims.mul(t371, t154) # t372: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t374 = prims.convert_element_type(t369, dtypes.float32) # t374: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t375 = prims.mul(t374, t157) # t375: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t376 = prims.add(t372, t375) # t376: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t377 = prims.convert_element_type(t376, dtypes.bfloat16) # t377: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t387 = prims.mul(t386, t154) # t387: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t389 = prims.convert_element_type(t384, dtypes.float32) # t389: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t390 = prims.mul(t389, t157) # t390: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t391 = prims.add(t387, t390) # t391: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t392 = prims.convert_element_type(t391, dtypes.bfloat16) # t392: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t363, t369, t378, t384\n", + " t394 = torch.cat((t377, t393), -1) # t394: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t394 = ltorch.cat((t377, t393), -1) # t394: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t394 = prims.cat((t377, t393), -1) # t394: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t377, t393\n", + " t396 = torch.cat((t392, t395), -1) # t396: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t396 = ltorch.cat((t392, t395), -1) # t396: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t396 = prims.cat((t392, t395), -1) # t396: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t392, t395\n", + " (t397, t398, t399, t400, _, _, t401, t402, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t394, t396, t362, 0.0, True, scale=0.08838834764831843)\n", + " t404 = torch.permute(t397, (0, 2, 1, 3)) # t404: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t404 = ltorch.permute(t397, (0, 2, 1, 3)) # t404: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t404 = prims.transpose(t397, (0, 2, 1, 3)) # t404: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t405 = torch.reshape(t404, (1, 512, 4096)) # t405: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t405 = ltorch.reshape(t404, (1, 512, 4096)) # t405: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t405 = prims.reshape(t404, (1, 512, 4096)) # t405: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t404\n", + " t406 = torch.nn.functional.linear(t405, t89, None) # t406: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t406 = ltorch.linear(t405, t89, None) # t406: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t406 = prims.linear(t405, t89, None) # t406: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t410, t417, t425] = nvFusion13(t338, t406, t421)\n", + " # t408 = prims.convert_element_type(t338, dtypes.float32) # t408: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t407 = prims.convert_element_type(t406, dtypes.float32) # t407: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t409 = prims.add(t407, t408) # t409: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t410 = prims.convert_element_type(t409, dtypes.bfloat16) # t410: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t412 = prims.mul(t409, t409) # t412: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t413 = prims.sum(t412, (2,)) # t413: \"cuda:0 f32[1, 512]\"\n", + " # t414 = prims.broadcast_in_dim(t413, [1, 512, 1], [0, 1]) # t414: \"cuda:0 f32[1, 512, 1]\"\n", + " # t415 = prims.div(t414, 4096.0) # t415: \"cuda:0 f32[1, 512, 1]\"\n", + " # t416 = prims.add(t415, 1e-05) # t416: \"cuda:0 f32[1, 512, 1]\"\n", + " # t417 = prims.rsqrt(t416) # t417: \"cuda:0 f32[1, 512, 1]\"\n", + " # t418 = prims.broadcast_in_dim(t417, (1, 512, 4096), (0, 1, 2)) # t418: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t419 = prims.mul(t409, t418) # t419: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t423 = prims.convert_element_type(t421, dtypes.float32) # t423: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t424 = prims.mul(t419, t423) # t424: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t425 = prims.convert_element_type(t424, dtypes.bfloat16) # t425: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t426 = torch.nn.functional.linear(t425, t21, None) # t426: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t426 = ltorch.linear(t425, t21, None) # t426: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t426 = prims.linear(t425, t21, None) # t426: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t427 = torch.nn.functional.linear(t425, t37, None) # t427: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t427 = ltorch.linear(t425, t37, None) # t427: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t427 = prims.linear(t425, t37, None) # t427: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t441] = nvFusion14(t426, t427)\n", + " # t428 = prims.convert_element_type(t426, dtypes.float32) # t428: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t429 = prims.neg(t428) # t429: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t430 = prims.exp(t429) # t430: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t431 = prims.add(1.0, t430) # t431: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t432 = prims.reciprocal(t431) # t432: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t436 = prims.mul(t428, t432) # t436: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t439 = prims.convert_element_type(t427, dtypes.float32) # t439: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t440 = prims.mul(t436, t439) # t440: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t441 = prims.convert_element_type(t440, dtypes.bfloat16) # t441: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t442 = torch.nn.functional.linear(t441, t90, None) # t442: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t442 = ltorch.linear(t441, t90, None) # t442: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t442 = prims.linear(t441, t90, None) # t442: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t446, t453, t461] = nvFusion15(t410, t442, t457)\n", + " # t444 = prims.convert_element_type(t410, dtypes.float32) # t444: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t443 = prims.convert_element_type(t442, dtypes.float32) # t443: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t445 = prims.add(t443, t444) # t445: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t446 = prims.convert_element_type(t445, dtypes.bfloat16) # t446: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t448 = prims.mul(t445, t445) # t448: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t449 = prims.sum(t448, (2,)) # t449: \"cuda:0 f32[1, 512]\"\n", + " # t450 = prims.broadcast_in_dim(t449, [1, 512, 1], [0, 1]) # t450: \"cuda:0 f32[1, 512, 1]\"\n", + " # t451 = prims.div(t450, 4096.0) # t451: \"cuda:0 f32[1, 512, 1]\"\n", + " # t452 = prims.add(t451, 1e-05) # t452: \"cuda:0 f32[1, 512, 1]\"\n", + " # t453 = prims.rsqrt(t452) # t453: \"cuda:0 f32[1, 512, 1]\"\n", + " # t454 = prims.broadcast_in_dim(t453, (1, 512, 4096), (0, 1, 2)) # t454: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t455 = prims.mul(t445, t454) # t455: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t459 = prims.convert_element_type(t457, dtypes.float32) # t459: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t460 = prims.mul(t455, t459) # t460: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t461 = prims.convert_element_type(t460, dtypes.bfloat16) # t461: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t462 = torch.nn.functional.linear(t461, t6, None) # t462: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t462 = ltorch.linear(t461, t6, None) # t462: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t462 = prims.linear(t461, t6, None) # t462: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t463 = torch.reshape(t462, (1, 512, 32, 3, 128)) # t463: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t463 = ltorch.reshape(t462, (1, 512, 32, 3, 128)) # t463: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t463 = prims.reshape(t462, (1, 512, 32, 3, 128)) # t463: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t462\n", + " t464 = torch.permute(t463, (0, 2, 3, 1, 4)) # t464: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t464 = ltorch.permute(t463, (0, 2, 3, 1, 4)) # t464: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t464 = prims.transpose(t463, (0, 2, 3, 1, 4)) # t464: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t463\n", + " (t465, t466, t467) = torch.split(t464, (1, 1, 1), 2)\n", + " # (t465, t466, t467) = ltorch.split(t464, (1, 1, 1), 2)\n", + " # t465 = prims.slice_prim(t464, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t465: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t466 = prims.slice_prim(t464, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t466: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t467 = prims.slice_prim(t464, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t467: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t464\n", + " t468 = torch.reshape(t465, (1, 32, 512, 128)) # t468: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t468 = ltorch.reshape(t465, (1, 32, 512, 128)) # t468: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t468 = prims.reshape(t465, (1, 32, 512, 128)) # t468: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t465\n", + " t469 = torch.reshape(t466, (1, 32, 512, 128)) # t469: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t469 = ltorch.reshape(t466, (1, 32, 512, 128)) # t469: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t469 = prims.reshape(t466, (1, 32, 512, 128)) # t469: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t466\n", + " t470 = torch.reshape(t467, (1, 32, 512, 128)) # t470: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t470 = ltorch.reshape(t467, (1, 32, 512, 128)) # t470: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t470 = prims.reshape(t467, (1, 32, 512, 128)) # t470: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t467\n", + " t471 = torch_slice_prim_impl(t468, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t471: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t486 = torch_slice_prim_impl(t469, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t486: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t501 = torch_slice_prim_impl(t468, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t501: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t468\n", + " t503 = torch_slice_prim_impl(t469, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t503: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t469\n", + " t472 = torch_slice_prim_impl(t471, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t472: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t473 = torch_slice_prim_impl(t471, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t473: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t487 = torch_slice_prim_impl(t486, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t487: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t488 = torch_slice_prim_impl(t486, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t488: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t476, t491] = nvFusion16(t471, t473, t486, t488)\n", + " # t474 = prims.convert_element_type(t473, dtypes.float32) # t474: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t475 = prims.neg(t474) # t475: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t476 = prims.convert_element_type(t475, dtypes.bfloat16) # t476: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t489 = prims.convert_element_type(t488, dtypes.float32) # t489: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t490 = prims.neg(t489) # t490: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t491 = prims.convert_element_type(t490, dtypes.bfloat16) # t491: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t473, t488\n", + " t477 = torch.cat((t476, t472), -1) # t477: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t477 = ltorch.cat((t476, t472), -1) # t477: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t477 = prims.cat((t476, t472), -1) # t477: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t476, t472\n", + " t492 = torch.cat((t491, t487), -1) # t492: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t492 = ltorch.cat((t491, t487), -1) # t492: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t492 = prims.cat((t491, t487), -1) # t492: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t491, t487\n", + " [t485, t500] = nvFusion17(t154, t157, t471, t477, t486, t492)\n", + " # t479 = prims.convert_element_type(t471, dtypes.float32) # t479: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t494 = prims.convert_element_type(t486, dtypes.float32) # t494: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t480 = prims.mul(t479, t154) # t480: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t482 = prims.convert_element_type(t477, dtypes.float32) # t482: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t483 = prims.mul(t482, t157) # t483: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t484 = prims.add(t480, t483) # t484: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t485 = prims.convert_element_type(t484, dtypes.bfloat16) # t485: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t495 = prims.mul(t494, t154) # t495: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t497 = prims.convert_element_type(t492, dtypes.float32) # t497: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t498 = prims.mul(t497, t157) # t498: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t499 = prims.add(t495, t498) # t499: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t500 = prims.convert_element_type(t499, dtypes.bfloat16) # t500: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t471, t477, t486, t492\n", + " t502 = torch.cat((t485, t501), -1) # t502: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t502 = ltorch.cat((t485, t501), -1) # t502: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t502 = prims.cat((t485, t501), -1) # t502: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t485, t501\n", + " t504 = torch.cat((t500, t503), -1) # t504: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t504 = ltorch.cat((t500, t503), -1) # t504: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t504 = prims.cat((t500, t503), -1) # t504: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t500, t503\n", + " (t505, t506, t507, t508, _, _, t509, t510, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t502, t504, t470, 0.0, True, scale=0.08838834764831843)\n", + " t512 = torch.permute(t505, (0, 2, 1, 3)) # t512: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t512 = ltorch.permute(t505, (0, 2, 1, 3)) # t512: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t512 = prims.transpose(t505, (0, 2, 1, 3)) # t512: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t513 = torch.reshape(t512, (1, 512, 4096)) # t513: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t513 = ltorch.reshape(t512, (1, 512, 4096)) # t513: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t513 = prims.reshape(t512, (1, 512, 4096)) # t513: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t512\n", + " t514 = torch.nn.functional.linear(t513, t91, None) # t514: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t514 = ltorch.linear(t513, t91, None) # t514: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t514 = prims.linear(t513, t91, None) # t514: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t518, t525, t533] = nvFusion18(t446, t514, t529)\n", + " # t516 = prims.convert_element_type(t446, dtypes.float32) # t516: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t515 = prims.convert_element_type(t514, dtypes.float32) # t515: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t517 = prims.add(t515, t516) # t517: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t518 = prims.convert_element_type(t517, dtypes.bfloat16) # t518: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t520 = prims.mul(t517, t517) # t520: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t521 = prims.sum(t520, (2,)) # t521: \"cuda:0 f32[1, 512]\"\n", + " # t522 = prims.broadcast_in_dim(t521, [1, 512, 1], [0, 1]) # t522: \"cuda:0 f32[1, 512, 1]\"\n", + " # t523 = prims.div(t522, 4096.0) # t523: \"cuda:0 f32[1, 512, 1]\"\n", + " # t524 = prims.add(t523, 1e-05) # t524: \"cuda:0 f32[1, 512, 1]\"\n", + " # t525 = prims.rsqrt(t524) # t525: \"cuda:0 f32[1, 512, 1]\"\n", + " # t526 = prims.broadcast_in_dim(t525, (1, 512, 4096), (0, 1, 2)) # t526: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t527 = prims.mul(t517, t526) # t527: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t531 = prims.convert_element_type(t529, dtypes.float32) # t531: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t532 = prims.mul(t527, t531) # t532: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t533 = prims.convert_element_type(t532, dtypes.bfloat16) # t533: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t534 = torch.nn.functional.linear(t533, t22, None) # t534: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t534 = ltorch.linear(t533, t22, None) # t534: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t534 = prims.linear(t533, t22, None) # t534: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t535 = torch.nn.functional.linear(t533, t38, None) # t535: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t535 = ltorch.linear(t533, t38, None) # t535: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t535 = prims.linear(t533, t38, None) # t535: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t549] = nvFusion19(t534, t535)\n", + " # t536 = prims.convert_element_type(t534, dtypes.float32) # t536: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t537 = prims.neg(t536) # t537: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t538 = prims.exp(t537) # t538: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t539 = prims.add(1.0, t538) # t539: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t540 = prims.reciprocal(t539) # t540: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t544 = prims.mul(t536, t540) # t544: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t547 = prims.convert_element_type(t535, dtypes.float32) # t547: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t548 = prims.mul(t544, t547) # t548: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t549 = prims.convert_element_type(t548, dtypes.bfloat16) # t549: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t550 = torch.nn.functional.linear(t549, t92, None) # t550: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t550 = ltorch.linear(t549, t92, None) # t550: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t550 = prims.linear(t549, t92, None) # t550: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t554, t561, t569] = nvFusion20(t518, t550, t565)\n", + " # t552 = prims.convert_element_type(t518, dtypes.float32) # t552: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t551 = prims.convert_element_type(t550, dtypes.float32) # t551: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t553 = prims.add(t551, t552) # t553: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t554 = prims.convert_element_type(t553, dtypes.bfloat16) # t554: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t556 = prims.mul(t553, t553) # t556: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t557 = prims.sum(t556, (2,)) # t557: \"cuda:0 f32[1, 512]\"\n", + " # t558 = prims.broadcast_in_dim(t557, [1, 512, 1], [0, 1]) # t558: \"cuda:0 f32[1, 512, 1]\"\n", + " # t559 = prims.div(t558, 4096.0) # t559: \"cuda:0 f32[1, 512, 1]\"\n", + " # t560 = prims.add(t559, 1e-05) # t560: \"cuda:0 f32[1, 512, 1]\"\n", + " # t561 = prims.rsqrt(t560) # t561: \"cuda:0 f32[1, 512, 1]\"\n", + " # t562 = prims.broadcast_in_dim(t561, (1, 512, 4096), (0, 1, 2)) # t562: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t563 = prims.mul(t553, t562) # t563: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t567 = prims.convert_element_type(t565, dtypes.float32) # t567: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t568 = prims.mul(t563, t567) # t568: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t569 = prims.convert_element_type(t568, dtypes.bfloat16) # t569: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t570 = torch.nn.functional.linear(t569, t7, None) # t570: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t570 = ltorch.linear(t569, t7, None) # t570: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t570 = prims.linear(t569, t7, None) # t570: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t571 = torch.reshape(t570, (1, 512, 32, 3, 128)) # t571: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t571 = ltorch.reshape(t570, (1, 512, 32, 3, 128)) # t571: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t571 = prims.reshape(t570, (1, 512, 32, 3, 128)) # t571: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t570\n", + " t572 = torch.permute(t571, (0, 2, 3, 1, 4)) # t572: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t572 = ltorch.permute(t571, (0, 2, 3, 1, 4)) # t572: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t572 = prims.transpose(t571, (0, 2, 3, 1, 4)) # t572: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t571\n", + " (t573, t574, t575) = torch.split(t572, (1, 1, 1), 2)\n", + " # (t573, t574, t575) = ltorch.split(t572, (1, 1, 1), 2)\n", + " # t573 = prims.slice_prim(t572, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t573: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t574 = prims.slice_prim(t572, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t574: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t575 = prims.slice_prim(t572, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t575: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t572\n", + " t576 = torch.reshape(t573, (1, 32, 512, 128)) # t576: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t576 = ltorch.reshape(t573, (1, 32, 512, 128)) # t576: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t576 = prims.reshape(t573, (1, 32, 512, 128)) # t576: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t573\n", + " t577 = torch.reshape(t574, (1, 32, 512, 128)) # t577: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t577 = ltorch.reshape(t574, (1, 32, 512, 128)) # t577: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t577 = prims.reshape(t574, (1, 32, 512, 128)) # t577: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t574\n", + " t578 = torch.reshape(t575, (1, 32, 512, 128)) # t578: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t578 = ltorch.reshape(t575, (1, 32, 512, 128)) # t578: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t578 = prims.reshape(t575, (1, 32, 512, 128)) # t578: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t575\n", + " t579 = torch_slice_prim_impl(t576, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t579: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t594 = torch_slice_prim_impl(t577, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t594: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t609 = torch_slice_prim_impl(t576, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t609: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t576\n", + " t611 = torch_slice_prim_impl(t577, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t611: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t577\n", + " t580 = torch_slice_prim_impl(t579, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t580: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t581 = torch_slice_prim_impl(t579, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t581: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t595 = torch_slice_prim_impl(t594, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t595: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t596 = torch_slice_prim_impl(t594, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t596: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t584, t599] = nvFusion21(t579, t581, t594, t596)\n", + " # t582 = prims.convert_element_type(t581, dtypes.float32) # t582: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t583 = prims.neg(t582) # t583: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t584 = prims.convert_element_type(t583, dtypes.bfloat16) # t584: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t597 = prims.convert_element_type(t596, dtypes.float32) # t597: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t598 = prims.neg(t597) # t598: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t599 = prims.convert_element_type(t598, dtypes.bfloat16) # t599: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t581, t596\n", + " t600 = torch.cat((t599, t595), -1) # t600: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t600 = ltorch.cat((t599, t595), -1) # t600: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t600 = prims.cat((t599, t595), -1) # t600: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t599, t595\n", + " t585 = torch.cat((t584, t580), -1) # t585: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t585 = ltorch.cat((t584, t580), -1) # t585: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t585 = prims.cat((t584, t580), -1) # t585: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t584, t580\n", + " [t593, t608] = nvFusion22(t154, t157, t579, t585, t594, t600)\n", + " # t587 = prims.convert_element_type(t579, dtypes.float32) # t587: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t602 = prims.convert_element_type(t594, dtypes.float32) # t602: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t603 = prims.mul(t602, t154) # t603: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t605 = prims.convert_element_type(t600, dtypes.float32) # t605: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t606 = prims.mul(t605, t157) # t606: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t607 = prims.add(t603, t606) # t607: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t608 = prims.convert_element_type(t607, dtypes.bfloat16) # t608: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t588 = prims.mul(t587, t154) # t588: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t590 = prims.convert_element_type(t585, dtypes.float32) # t590: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t591 = prims.mul(t590, t157) # t591: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t592 = prims.add(t588, t591) # t592: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t593 = prims.convert_element_type(t592, dtypes.bfloat16) # t593: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t579, t585, t594, t600\n", + " t612 = torch.cat((t608, t611), -1) # t612: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t612 = ltorch.cat((t608, t611), -1) # t612: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t612 = prims.cat((t608, t611), -1) # t612: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t608, t611\n", + " t610 = torch.cat((t593, t609), -1) # t610: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t610 = ltorch.cat((t593, t609), -1) # t610: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t610 = prims.cat((t593, t609), -1) # t610: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t593, t609\n", + " (t613, t614, t615, t616, _, _, t617, t618, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t610, t612, t578, 0.0, True, scale=0.08838834764831843)\n", + " t620 = torch.permute(t613, (0, 2, 1, 3)) # t620: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t620 = ltorch.permute(t613, (0, 2, 1, 3)) # t620: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t620 = prims.transpose(t613, (0, 2, 1, 3)) # t620: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t621 = torch.reshape(t620, (1, 512, 4096)) # t621: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t621 = ltorch.reshape(t620, (1, 512, 4096)) # t621: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t621 = prims.reshape(t620, (1, 512, 4096)) # t621: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t620\n", + " t622 = torch.nn.functional.linear(t621, t93, None) # t622: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t622 = ltorch.linear(t621, t93, None) # t622: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t622 = prims.linear(t621, t93, None) # t622: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t626, t633, t641] = nvFusion23(t554, t622, t637)\n", + " # t624 = prims.convert_element_type(t554, dtypes.float32) # t624: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t623 = prims.convert_element_type(t622, dtypes.float32) # t623: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t625 = prims.add(t623, t624) # t625: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t626 = prims.convert_element_type(t625, dtypes.bfloat16) # t626: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t628 = prims.mul(t625, t625) # t628: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t629 = prims.sum(t628, (2,)) # t629: \"cuda:0 f32[1, 512]\"\n", + " # t630 = prims.broadcast_in_dim(t629, [1, 512, 1], [0, 1]) # t630: \"cuda:0 f32[1, 512, 1]\"\n", + " # t631 = prims.div(t630, 4096.0) # t631: \"cuda:0 f32[1, 512, 1]\"\n", + " # t632 = prims.add(t631, 1e-05) # t632: \"cuda:0 f32[1, 512, 1]\"\n", + " # t633 = prims.rsqrt(t632) # t633: \"cuda:0 f32[1, 512, 1]\"\n", + " # t634 = prims.broadcast_in_dim(t633, (1, 512, 4096), (0, 1, 2)) # t634: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t635 = prims.mul(t625, t634) # t635: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t639 = prims.convert_element_type(t637, dtypes.float32) # t639: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t640 = prims.mul(t635, t639) # t640: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t641 = prims.convert_element_type(t640, dtypes.bfloat16) # t641: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t643 = torch.nn.functional.linear(t641, t39, None) # t643: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t643 = ltorch.linear(t641, t39, None) # t643: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t643 = prims.linear(t641, t39, None) # t643: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t642 = torch.nn.functional.linear(t641, t23, None) # t642: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t642 = ltorch.linear(t641, t23, None) # t642: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t642 = prims.linear(t641, t23, None) # t642: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t657] = nvFusion24(t642, t643)\n", + " # t644 = prims.convert_element_type(t642, dtypes.float32) # t644: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t645 = prims.neg(t644) # t645: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t646 = prims.exp(t645) # t646: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t647 = prims.add(1.0, t646) # t647: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t648 = prims.reciprocal(t647) # t648: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t652 = prims.mul(t644, t648) # t652: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t655 = prims.convert_element_type(t643, dtypes.float32) # t655: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t656 = prims.mul(t652, t655) # t656: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t657 = prims.convert_element_type(t656, dtypes.bfloat16) # t657: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t658 = torch.nn.functional.linear(t657, t94, None) # t658: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t658 = ltorch.linear(t657, t94, None) # t658: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t658 = prims.linear(t657, t94, None) # t658: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t662, t669, t677] = nvFusion25(t626, t658, t673)\n", + " # t660 = prims.convert_element_type(t626, dtypes.float32) # t660: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t659 = prims.convert_element_type(t658, dtypes.float32) # t659: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t661 = prims.add(t659, t660) # t661: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t662 = prims.convert_element_type(t661, dtypes.bfloat16) # t662: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t664 = prims.mul(t661, t661) # t664: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t665 = prims.sum(t664, (2,)) # t665: \"cuda:0 f32[1, 512]\"\n", + " # t666 = prims.broadcast_in_dim(t665, [1, 512, 1], [0, 1]) # t666: \"cuda:0 f32[1, 512, 1]\"\n", + " # t667 = prims.div(t666, 4096.0) # t667: \"cuda:0 f32[1, 512, 1]\"\n", + " # t668 = prims.add(t667, 1e-05) # t668: \"cuda:0 f32[1, 512, 1]\"\n", + " # t669 = prims.rsqrt(t668) # t669: \"cuda:0 f32[1, 512, 1]\"\n", + " # t670 = prims.broadcast_in_dim(t669, (1, 512, 4096), (0, 1, 2)) # t670: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t671 = prims.mul(t661, t670) # t671: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t675 = prims.convert_element_type(t673, dtypes.float32) # t675: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t676 = prims.mul(t671, t675) # t676: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t677 = prims.convert_element_type(t676, dtypes.bfloat16) # t677: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t678 = torch.nn.functional.linear(t677, t8, None) # t678: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t678 = ltorch.linear(t677, t8, None) # t678: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t678 = prims.linear(t677, t8, None) # t678: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t679 = torch.reshape(t678, (1, 512, 32, 3, 128)) # t679: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t679 = ltorch.reshape(t678, (1, 512, 32, 3, 128)) # t679: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t679 = prims.reshape(t678, (1, 512, 32, 3, 128)) # t679: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t678\n", + " t680 = torch.permute(t679, (0, 2, 3, 1, 4)) # t680: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t680 = ltorch.permute(t679, (0, 2, 3, 1, 4)) # t680: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t680 = prims.transpose(t679, (0, 2, 3, 1, 4)) # t680: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t679\n", + " (t681, t682, t683) = torch.split(t680, (1, 1, 1), 2)\n", + " # (t681, t682, t683) = ltorch.split(t680, (1, 1, 1), 2)\n", + " # t681 = prims.slice_prim(t680, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t681: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t682 = prims.slice_prim(t680, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t682: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t683 = prims.slice_prim(t680, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t683: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t680\n", + " t684 = torch.reshape(t681, (1, 32, 512, 128)) # t684: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t684 = ltorch.reshape(t681, (1, 32, 512, 128)) # t684: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t684 = prims.reshape(t681, (1, 32, 512, 128)) # t684: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t681\n", + " t685 = torch.reshape(t682, (1, 32, 512, 128)) # t685: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t685 = ltorch.reshape(t682, (1, 32, 512, 128)) # t685: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t685 = prims.reshape(t682, (1, 32, 512, 128)) # t685: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t682\n", + " t686 = torch.reshape(t683, (1, 32, 512, 128)) # t686: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t686 = ltorch.reshape(t683, (1, 32, 512, 128)) # t686: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t686 = prims.reshape(t683, (1, 32, 512, 128)) # t686: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t683\n", + " t687 = torch_slice_prim_impl(t684, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t687: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t702 = torch_slice_prim_impl(t685, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t702: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t717 = torch_slice_prim_impl(t684, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t717: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t684\n", + " t719 = torch_slice_prim_impl(t685, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t719: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t685\n", + " t688 = torch_slice_prim_impl(t687, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t688: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t689 = torch_slice_prim_impl(t687, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t689: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t703 = torch_slice_prim_impl(t702, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t703: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t704 = torch_slice_prim_impl(t702, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t704: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t692, t707] = nvFusion26(t687, t689, t702, t704)\n", + " # t690 = prims.convert_element_type(t689, dtypes.float32) # t690: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t691 = prims.neg(t690) # t691: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t692 = prims.convert_element_type(t691, dtypes.bfloat16) # t692: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t705 = prims.convert_element_type(t704, dtypes.float32) # t705: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t706 = prims.neg(t705) # t706: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t707 = prims.convert_element_type(t706, dtypes.bfloat16) # t707: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t689, t704\n", + " t708 = torch.cat((t707, t703), -1) # t708: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t708 = ltorch.cat((t707, t703), -1) # t708: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t708 = prims.cat((t707, t703), -1) # t708: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t707, t703\n", + " t693 = torch.cat((t692, t688), -1) # t693: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t693 = ltorch.cat((t692, t688), -1) # t693: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t693 = prims.cat((t692, t688), -1) # t693: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t692, t688\n", + " [t701, t716] = nvFusion27(t154, t157, t687, t693, t702, t708)\n", + " # t695 = prims.convert_element_type(t687, dtypes.float32) # t695: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t710 = prims.convert_element_type(t702, dtypes.float32) # t710: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t711 = prims.mul(t710, t154) # t711: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t713 = prims.convert_element_type(t708, dtypes.float32) # t713: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t714 = prims.mul(t713, t157) # t714: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t715 = prims.add(t711, t714) # t715: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t716 = prims.convert_element_type(t715, dtypes.bfloat16) # t716: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t696 = prims.mul(t695, t154) # t696: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t698 = prims.convert_element_type(t693, dtypes.float32) # t698: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t699 = prims.mul(t698, t157) # t699: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t700 = prims.add(t696, t699) # t700: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t701 = prims.convert_element_type(t700, dtypes.bfloat16) # t701: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t687, t693, t702, t708\n", + " t720 = torch.cat((t716, t719), -1) # t720: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t720 = ltorch.cat((t716, t719), -1) # t720: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t720 = prims.cat((t716, t719), -1) # t720: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t716, t719\n", + " t718 = torch.cat((t701, t717), -1) # t718: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t718 = ltorch.cat((t701, t717), -1) # t718: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t718 = prims.cat((t701, t717), -1) # t718: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t701, t717\n", + " (t721, t722, t723, t724, _, _, t725, t726, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t718, t720, t686, 0.0, True, scale=0.08838834764831843)\n", + " t728 = torch.permute(t721, (0, 2, 1, 3)) # t728: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t728 = ltorch.permute(t721, (0, 2, 1, 3)) # t728: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t728 = prims.transpose(t721, (0, 2, 1, 3)) # t728: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t729 = torch.reshape(t728, (1, 512, 4096)) # t729: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t729 = ltorch.reshape(t728, (1, 512, 4096)) # t729: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t729 = prims.reshape(t728, (1, 512, 4096)) # t729: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t728\n", + " t730 = torch.nn.functional.linear(t729, t95, None) # t730: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t730 = ltorch.linear(t729, t95, None) # t730: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t730 = prims.linear(t729, t95, None) # t730: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t734, t741, t749] = nvFusion28(t662, t730, t745)\n", + " # t732 = prims.convert_element_type(t662, dtypes.float32) # t732: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t731 = prims.convert_element_type(t730, dtypes.float32) # t731: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t733 = prims.add(t731, t732) # t733: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t734 = prims.convert_element_type(t733, dtypes.bfloat16) # t734: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t736 = prims.mul(t733, t733) # t736: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t737 = prims.sum(t736, (2,)) # t737: \"cuda:0 f32[1, 512]\"\n", + " # t738 = prims.broadcast_in_dim(t737, [1, 512, 1], [0, 1]) # t738: \"cuda:0 f32[1, 512, 1]\"\n", + " # t739 = prims.div(t738, 4096.0) # t739: \"cuda:0 f32[1, 512, 1]\"\n", + " # t740 = prims.add(t739, 1e-05) # t740: \"cuda:0 f32[1, 512, 1]\"\n", + " # t741 = prims.rsqrt(t740) # t741: \"cuda:0 f32[1, 512, 1]\"\n", + " # t742 = prims.broadcast_in_dim(t741, (1, 512, 4096), (0, 1, 2)) # t742: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t743 = prims.mul(t733, t742) # t743: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t747 = prims.convert_element_type(t745, dtypes.float32) # t747: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t748 = prims.mul(t743, t747) # t748: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t749 = prims.convert_element_type(t748, dtypes.bfloat16) # t749: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t750 = torch.nn.functional.linear(t749, t24, None) # t750: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t750 = ltorch.linear(t749, t24, None) # t750: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t750 = prims.linear(t749, t24, None) # t750: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t751 = torch.nn.functional.linear(t749, t40, None) # t751: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t751 = ltorch.linear(t749, t40, None) # t751: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t751 = prims.linear(t749, t40, None) # t751: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t765] = nvFusion29(t750, t751)\n", + " # t752 = prims.convert_element_type(t750, dtypes.float32) # t752: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t753 = prims.neg(t752) # t753: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t754 = prims.exp(t753) # t754: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t755 = prims.add(1.0, t754) # t755: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t756 = prims.reciprocal(t755) # t756: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t760 = prims.mul(t752, t756) # t760: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t763 = prims.convert_element_type(t751, dtypes.float32) # t763: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t764 = prims.mul(t760, t763) # t764: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t765 = prims.convert_element_type(t764, dtypes.bfloat16) # t765: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t766 = torch.nn.functional.linear(t765, t96, None) # t766: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t766 = ltorch.linear(t765, t96, None) # t766: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t766 = prims.linear(t765, t96, None) # t766: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t770, t777, t785] = nvFusion30(t734, t766, t781)\n", + " # t768 = prims.convert_element_type(t734, dtypes.float32) # t768: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t767 = prims.convert_element_type(t766, dtypes.float32) # t767: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t769 = prims.add(t767, t768) # t769: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t770 = prims.convert_element_type(t769, dtypes.bfloat16) # t770: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t772 = prims.mul(t769, t769) # t772: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t773 = prims.sum(t772, (2,)) # t773: \"cuda:0 f32[1, 512]\"\n", + " # t774 = prims.broadcast_in_dim(t773, [1, 512, 1], [0, 1]) # t774: \"cuda:0 f32[1, 512, 1]\"\n", + " # t775 = prims.div(t774, 4096.0) # t775: \"cuda:0 f32[1, 512, 1]\"\n", + " # t776 = prims.add(t775, 1e-05) # t776: \"cuda:0 f32[1, 512, 1]\"\n", + " # t777 = prims.rsqrt(t776) # t777: \"cuda:0 f32[1, 512, 1]\"\n", + " # t778 = prims.broadcast_in_dim(t777, (1, 512, 4096), (0, 1, 2)) # t778: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t779 = prims.mul(t769, t778) # t779: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t783 = prims.convert_element_type(t781, dtypes.float32) # t783: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t784 = prims.mul(t779, t783) # t784: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t785 = prims.convert_element_type(t784, dtypes.bfloat16) # t785: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t786 = torch.nn.functional.linear(t785, t9, None) # t786: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t786 = ltorch.linear(t785, t9, None) # t786: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t786 = prims.linear(t785, t9, None) # t786: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t787 = torch.reshape(t786, (1, 512, 32, 3, 128)) # t787: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t787 = ltorch.reshape(t786, (1, 512, 32, 3, 128)) # t787: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t787 = prims.reshape(t786, (1, 512, 32, 3, 128)) # t787: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t786\n", + " t788 = torch.permute(t787, (0, 2, 3, 1, 4)) # t788: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t788 = ltorch.permute(t787, (0, 2, 3, 1, 4)) # t788: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t788 = prims.transpose(t787, (0, 2, 3, 1, 4)) # t788: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t787\n", + " (t789, t790, t791) = torch.split(t788, (1, 1, 1), 2)\n", + " # (t789, t790, t791) = ltorch.split(t788, (1, 1, 1), 2)\n", + " # t789 = prims.slice_prim(t788, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t789: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t790 = prims.slice_prim(t788, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t790: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t791 = prims.slice_prim(t788, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t791: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t788\n", + " t792 = torch.reshape(t789, (1, 32, 512, 128)) # t792: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t792 = ltorch.reshape(t789, (1, 32, 512, 128)) # t792: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t792 = prims.reshape(t789, (1, 32, 512, 128)) # t792: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t789\n", + " t793 = torch.reshape(t790, (1, 32, 512, 128)) # t793: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t793 = ltorch.reshape(t790, (1, 32, 512, 128)) # t793: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t793 = prims.reshape(t790, (1, 32, 512, 128)) # t793: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t790\n", + " t794 = torch.reshape(t791, (1, 32, 512, 128)) # t794: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t794 = ltorch.reshape(t791, (1, 32, 512, 128)) # t794: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t794 = prims.reshape(t791, (1, 32, 512, 128)) # t794: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t791\n", + " t795 = torch_slice_prim_impl(t792, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t795: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t810 = torch_slice_prim_impl(t793, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t810: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t825 = torch_slice_prim_impl(t792, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t825: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t792\n", + " t827 = torch_slice_prim_impl(t793, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t827: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t793\n", + " t796 = torch_slice_prim_impl(t795, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t796: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t797 = torch_slice_prim_impl(t795, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t797: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t811 = torch_slice_prim_impl(t810, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t811: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t812 = torch_slice_prim_impl(t810, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t812: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t800, t815] = nvFusion31(t795, t797, t810, t812)\n", + " # t798 = prims.convert_element_type(t797, dtypes.float32) # t798: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t799 = prims.neg(t798) # t799: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t800 = prims.convert_element_type(t799, dtypes.bfloat16) # t800: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t813 = prims.convert_element_type(t812, dtypes.float32) # t813: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t814 = prims.neg(t813) # t814: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t815 = prims.convert_element_type(t814, dtypes.bfloat16) # t815: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t797, t812\n", + " t816 = torch.cat((t815, t811), -1) # t816: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t816 = ltorch.cat((t815, t811), -1) # t816: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t816 = prims.cat((t815, t811), -1) # t816: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t815, t811\n", + " t801 = torch.cat((t800, t796), -1) # t801: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t801 = ltorch.cat((t800, t796), -1) # t801: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t801 = prims.cat((t800, t796), -1) # t801: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t800, t796\n", + " [t809, t824] = nvFusion32(t154, t157, t795, t801, t810, t816)\n", + " # t803 = prims.convert_element_type(t795, dtypes.float32) # t803: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t818 = prims.convert_element_type(t810, dtypes.float32) # t818: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t819 = prims.mul(t818, t154) # t819: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t821 = prims.convert_element_type(t816, dtypes.float32) # t821: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t822 = prims.mul(t821, t157) # t822: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t823 = prims.add(t819, t822) # t823: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t824 = prims.convert_element_type(t823, dtypes.bfloat16) # t824: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t804 = prims.mul(t803, t154) # t804: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t806 = prims.convert_element_type(t801, dtypes.float32) # t806: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t807 = prims.mul(t806, t157) # t807: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t808 = prims.add(t804, t807) # t808: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t809 = prims.convert_element_type(t808, dtypes.bfloat16) # t809: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t795, t801, t810, t816\n", + " t828 = torch.cat((t824, t827), -1) # t828: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t828 = ltorch.cat((t824, t827), -1) # t828: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t828 = prims.cat((t824, t827), -1) # t828: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t824, t827\n", + " t826 = torch.cat((t809, t825), -1) # t826: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t826 = ltorch.cat((t809, t825), -1) # t826: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t826 = prims.cat((t809, t825), -1) # t826: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t809, t825\n", + " (t829, t830, t831, t832, _, _, t833, t834, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t826, t828, t794, 0.0, True, scale=0.08838834764831843)\n", + " t836 = torch.permute(t829, (0, 2, 1, 3)) # t836: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t836 = ltorch.permute(t829, (0, 2, 1, 3)) # t836: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t836 = prims.transpose(t829, (0, 2, 1, 3)) # t836: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t837 = torch.reshape(t836, (1, 512, 4096)) # t837: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t837 = ltorch.reshape(t836, (1, 512, 4096)) # t837: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t837 = prims.reshape(t836, (1, 512, 4096)) # t837: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t836\n", + " t838 = torch.nn.functional.linear(t837, t97, None) # t838: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t838 = ltorch.linear(t837, t97, None) # t838: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t838 = prims.linear(t837, t97, None) # t838: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t842, t849, t857] = nvFusion33(t770, t838, t853)\n", + " # t840 = prims.convert_element_type(t770, dtypes.float32) # t840: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t839 = prims.convert_element_type(t838, dtypes.float32) # t839: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t841 = prims.add(t839, t840) # t841: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t842 = prims.convert_element_type(t841, dtypes.bfloat16) # t842: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t844 = prims.mul(t841, t841) # t844: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t845 = prims.sum(t844, (2,)) # t845: \"cuda:0 f32[1, 512]\"\n", + " # t846 = prims.broadcast_in_dim(t845, [1, 512, 1], [0, 1]) # t846: \"cuda:0 f32[1, 512, 1]\"\n", + " # t847 = prims.div(t846, 4096.0) # t847: \"cuda:0 f32[1, 512, 1]\"\n", + " # t848 = prims.add(t847, 1e-05) # t848: \"cuda:0 f32[1, 512, 1]\"\n", + " # t849 = prims.rsqrt(t848) # t849: \"cuda:0 f32[1, 512, 1]\"\n", + " # t850 = prims.broadcast_in_dim(t849, (1, 512, 4096), (0, 1, 2)) # t850: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t851 = prims.mul(t841, t850) # t851: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t855 = prims.convert_element_type(t853, dtypes.float32) # t855: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t856 = prims.mul(t851, t855) # t856: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t857 = prims.convert_element_type(t856, dtypes.bfloat16) # t857: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t858 = torch.nn.functional.linear(t857, t25, None) # t858: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t858 = ltorch.linear(t857, t25, None) # t858: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t858 = prims.linear(t857, t25, None) # t858: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t859 = torch.nn.functional.linear(t857, t41, None) # t859: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t859 = ltorch.linear(t857, t41, None) # t859: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t859 = prims.linear(t857, t41, None) # t859: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t873] = nvFusion34(t858, t859)\n", + " # t860 = prims.convert_element_type(t858, dtypes.float32) # t860: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t861 = prims.neg(t860) # t861: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t862 = prims.exp(t861) # t862: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t863 = prims.add(1.0, t862) # t863: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t864 = prims.reciprocal(t863) # t864: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t868 = prims.mul(t860, t864) # t868: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t871 = prims.convert_element_type(t859, dtypes.float32) # t871: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t872 = prims.mul(t868, t871) # t872: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t873 = prims.convert_element_type(t872, dtypes.bfloat16) # t873: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t874 = torch.nn.functional.linear(t873, t98, None) # t874: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t874 = ltorch.linear(t873, t98, None) # t874: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t874 = prims.linear(t873, t98, None) # t874: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t878, t885, t893] = nvFusion35(t842, t874, t889)\n", + " # t876 = prims.convert_element_type(t842, dtypes.float32) # t876: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t875 = prims.convert_element_type(t874, dtypes.float32) # t875: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t877 = prims.add(t875, t876) # t877: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t878 = prims.convert_element_type(t877, dtypes.bfloat16) # t878: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t880 = prims.mul(t877, t877) # t880: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t881 = prims.sum(t880, (2,)) # t881: \"cuda:0 f32[1, 512]\"\n", + " # t882 = prims.broadcast_in_dim(t881, [1, 512, 1], [0, 1]) # t882: \"cuda:0 f32[1, 512, 1]\"\n", + " # t883 = prims.div(t882, 4096.0) # t883: \"cuda:0 f32[1, 512, 1]\"\n", + " # t884 = prims.add(t883, 1e-05) # t884: \"cuda:0 f32[1, 512, 1]\"\n", + " # t885 = prims.rsqrt(t884) # t885: \"cuda:0 f32[1, 512, 1]\"\n", + " # t886 = prims.broadcast_in_dim(t885, (1, 512, 4096), (0, 1, 2)) # t886: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t887 = prims.mul(t877, t886) # t887: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t891 = prims.convert_element_type(t889, dtypes.float32) # t891: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t892 = prims.mul(t887, t891) # t892: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t893 = prims.convert_element_type(t892, dtypes.bfloat16) # t893: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t894 = torch.nn.functional.linear(t893, t10, None) # t894: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t894 = ltorch.linear(t893, t10, None) # t894: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t894 = prims.linear(t893, t10, None) # t894: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t895 = torch.reshape(t894, (1, 512, 32, 3, 128)) # t895: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t895 = ltorch.reshape(t894, (1, 512, 32, 3, 128)) # t895: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t895 = prims.reshape(t894, (1, 512, 32, 3, 128)) # t895: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t894\n", + " t896 = torch.permute(t895, (0, 2, 3, 1, 4)) # t896: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t896 = ltorch.permute(t895, (0, 2, 3, 1, 4)) # t896: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t896 = prims.transpose(t895, (0, 2, 3, 1, 4)) # t896: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t895\n", + " (t897, t898, t899) = torch.split(t896, (1, 1, 1), 2)\n", + " # (t897, t898, t899) = ltorch.split(t896, (1, 1, 1), 2)\n", + " # t897 = prims.slice_prim(t896, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t897: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t898 = prims.slice_prim(t896, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t898: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t899 = prims.slice_prim(t896, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t899: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t896\n", + " t900 = torch.reshape(t897, (1, 32, 512, 128)) # t900: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t900 = ltorch.reshape(t897, (1, 32, 512, 128)) # t900: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t900 = prims.reshape(t897, (1, 32, 512, 128)) # t900: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t897\n", + " t901 = torch.reshape(t898, (1, 32, 512, 128)) # t901: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t901 = ltorch.reshape(t898, (1, 32, 512, 128)) # t901: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t901 = prims.reshape(t898, (1, 32, 512, 128)) # t901: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t898\n", + " t902 = torch.reshape(t899, (1, 32, 512, 128)) # t902: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t902 = ltorch.reshape(t899, (1, 32, 512, 128)) # t902: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t902 = prims.reshape(t899, (1, 32, 512, 128)) # t902: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t899\n", + " t935 = torch_slice_prim_impl(t901, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t935: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " t903 = torch_slice_prim_impl(t900, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t903: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t918 = torch_slice_prim_impl(t901, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t918: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t901\n", + " t933 = torch_slice_prim_impl(t900, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t933: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t900\n", + " t904 = torch_slice_prim_impl(t903, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t904: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t905 = torch_slice_prim_impl(t903, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t905: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t919 = torch_slice_prim_impl(t918, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t919: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t920 = torch_slice_prim_impl(t918, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t920: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t908, t923] = nvFusion36(t903, t905, t918, t920)\n", + " # t906 = prims.convert_element_type(t905, dtypes.float32) # t906: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t907 = prims.neg(t906) # t907: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t908 = prims.convert_element_type(t907, dtypes.bfloat16) # t908: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t921 = prims.convert_element_type(t920, dtypes.float32) # t921: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t922 = prims.neg(t921) # t922: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t923 = prims.convert_element_type(t922, dtypes.bfloat16) # t923: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t905, t920\n", + " t924 = torch.cat((t923, t919), -1) # t924: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t924 = ltorch.cat((t923, t919), -1) # t924: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t924 = prims.cat((t923, t919), -1) # t924: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t923, t919\n", + " t909 = torch.cat((t908, t904), -1) # t909: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t909 = ltorch.cat((t908, t904), -1) # t909: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t909 = prims.cat((t908, t904), -1) # t909: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t908, t904\n", + " [t917, t932] = nvFusion37(t154, t157, t903, t909, t918, t924)\n", + " # t911 = prims.convert_element_type(t903, dtypes.float32) # t911: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t926 = prims.convert_element_type(t918, dtypes.float32) # t926: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t927 = prims.mul(t926, t154) # t927: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t929 = prims.convert_element_type(t924, dtypes.float32) # t929: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t930 = prims.mul(t929, t157) # t930: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t931 = prims.add(t927, t930) # t931: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t932 = prims.convert_element_type(t931, dtypes.bfloat16) # t932: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t912 = prims.mul(t911, t154) # t912: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t914 = prims.convert_element_type(t909, dtypes.float32) # t914: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t915 = prims.mul(t914, t157) # t915: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t916 = prims.add(t912, t915) # t916: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t917 = prims.convert_element_type(t916, dtypes.bfloat16) # t917: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t903, t909, t918, t924\n", + " t936 = torch.cat((t932, t935), -1) # t936: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t936 = ltorch.cat((t932, t935), -1) # t936: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t936 = prims.cat((t932, t935), -1) # t936: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t932, t935\n", + " t934 = torch.cat((t917, t933), -1) # t934: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t934 = ltorch.cat((t917, t933), -1) # t934: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t934 = prims.cat((t917, t933), -1) # t934: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t917, t933\n", + " (t937, t938, t939, t940, _, _, t941, t942, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t934, t936, t902, 0.0, True, scale=0.08838834764831843)\n", + " t944 = torch.permute(t937, (0, 2, 1, 3)) # t944: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t944 = ltorch.permute(t937, (0, 2, 1, 3)) # t944: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t944 = prims.transpose(t937, (0, 2, 1, 3)) # t944: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t945 = torch.reshape(t944, (1, 512, 4096)) # t945: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t945 = ltorch.reshape(t944, (1, 512, 4096)) # t945: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t945 = prims.reshape(t944, (1, 512, 4096)) # t945: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t944\n", + " t946 = torch.nn.functional.linear(t945, t99, None) # t946: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t946 = ltorch.linear(t945, t99, None) # t946: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t946 = prims.linear(t945, t99, None) # t946: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t950, t957, t965] = nvFusion38(t878, t946, t961)\n", + " # t948 = prims.convert_element_type(t878, dtypes.float32) # t948: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t947 = prims.convert_element_type(t946, dtypes.float32) # t947: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t949 = prims.add(t947, t948) # t949: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t950 = prims.convert_element_type(t949, dtypes.bfloat16) # t950: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t952 = prims.mul(t949, t949) # t952: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t953 = prims.sum(t952, (2,)) # t953: \"cuda:0 f32[1, 512]\"\n", + " # t954 = prims.broadcast_in_dim(t953, [1, 512, 1], [0, 1]) # t954: \"cuda:0 f32[1, 512, 1]\"\n", + " # t955 = prims.div(t954, 4096.0) # t955: \"cuda:0 f32[1, 512, 1]\"\n", + " # t956 = prims.add(t955, 1e-05) # t956: \"cuda:0 f32[1, 512, 1]\"\n", + " # t957 = prims.rsqrt(t956) # t957: \"cuda:0 f32[1, 512, 1]\"\n", + " # t958 = prims.broadcast_in_dim(t957, (1, 512, 4096), (0, 1, 2)) # t958: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t959 = prims.mul(t949, t958) # t959: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t963 = prims.convert_element_type(t961, dtypes.float32) # t963: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t964 = prims.mul(t959, t963) # t964: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t965 = prims.convert_element_type(t964, dtypes.bfloat16) # t965: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t967 = torch.nn.functional.linear(t965, t42, None) # t967: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t967 = ltorch.linear(t965, t42, None) # t967: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t967 = prims.linear(t965, t42, None) # t967: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t966 = torch.nn.functional.linear(t965, t26, None) # t966: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t966 = ltorch.linear(t965, t26, None) # t966: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t966 = prims.linear(t965, t26, None) # t966: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t981] = nvFusion39(t966, t967)\n", + " # t968 = prims.convert_element_type(t966, dtypes.float32) # t968: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t969 = prims.neg(t968) # t969: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t970 = prims.exp(t969) # t970: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t971 = prims.add(1.0, t970) # t971: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t972 = prims.reciprocal(t971) # t972: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t976 = prims.mul(t968, t972) # t976: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t979 = prims.convert_element_type(t967, dtypes.float32) # t979: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t980 = prims.mul(t976, t979) # t980: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t981 = prims.convert_element_type(t980, dtypes.bfloat16) # t981: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t982 = torch.nn.functional.linear(t981, t100, None) # t982: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t982 = ltorch.linear(t981, t100, None) # t982: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t982 = prims.linear(t981, t100, None) # t982: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1001, t986, t993] = nvFusion40(t950, t982, t997)\n", + " # t984 = prims.convert_element_type(t950, dtypes.float32) # t984: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t983 = prims.convert_element_type(t982, dtypes.float32) # t983: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t985 = prims.add(t983, t984) # t985: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t986 = prims.convert_element_type(t985, dtypes.bfloat16) # t986: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t988 = prims.mul(t985, t985) # t988: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t989 = prims.sum(t988, (2,)) # t989: \"cuda:0 f32[1, 512]\"\n", + " # t990 = prims.broadcast_in_dim(t989, [1, 512, 1], [0, 1]) # t990: \"cuda:0 f32[1, 512, 1]\"\n", + " # t991 = prims.div(t990, 4096.0) # t991: \"cuda:0 f32[1, 512, 1]\"\n", + " # t992 = prims.add(t991, 1e-05) # t992: \"cuda:0 f32[1, 512, 1]\"\n", + " # t993 = prims.rsqrt(t992) # t993: \"cuda:0 f32[1, 512, 1]\"\n", + " # t994 = prims.broadcast_in_dim(t993, (1, 512, 4096), (0, 1, 2)) # t994: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t995 = prims.mul(t985, t994) # t995: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t999 = prims.convert_element_type(t997, dtypes.float32) # t999: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1000 = prims.mul(t995, t999) # t1000: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1001 = prims.convert_element_type(t1000, dtypes.bfloat16) # t1001: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1002 = torch.nn.functional.linear(t1001, t11, None) # t1002: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1002 = ltorch.linear(t1001, t11, None) # t1002: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1002 = prims.linear(t1001, t11, None) # t1002: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1003 = torch.reshape(t1002, (1, 512, 32, 3, 128)) # t1003: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1003 = ltorch.reshape(t1002, (1, 512, 32, 3, 128)) # t1003: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1003 = prims.reshape(t1002, (1, 512, 32, 3, 128)) # t1003: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1002\n", + " t1004 = torch.permute(t1003, (0, 2, 3, 1, 4)) # t1004: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1004 = ltorch.permute(t1003, (0, 2, 3, 1, 4)) # t1004: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1004 = prims.transpose(t1003, (0, 2, 3, 1, 4)) # t1004: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1003\n", + " (t1005, t1006, t1007) = torch.split(t1004, (1, 1, 1), 2)\n", + " # (t1005, t1006, t1007) = ltorch.split(t1004, (1, 1, 1), 2)\n", + " # t1005 = prims.slice_prim(t1004, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1005: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1006 = prims.slice_prim(t1004, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1006: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1007 = prims.slice_prim(t1004, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1007: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1004\n", + " t1008 = torch.reshape(t1005, (1, 32, 512, 128)) # t1008: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1008 = ltorch.reshape(t1005, (1, 32, 512, 128)) # t1008: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1008 = prims.reshape(t1005, (1, 32, 512, 128)) # t1008: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1005\n", + " t1009 = torch.reshape(t1006, (1, 32, 512, 128)) # t1009: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1009 = ltorch.reshape(t1006, (1, 32, 512, 128)) # t1009: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1009 = prims.reshape(t1006, (1, 32, 512, 128)) # t1009: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1006\n", + " t1010 = torch.reshape(t1007, (1, 32, 512, 128)) # t1010: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1010 = ltorch.reshape(t1007, (1, 32, 512, 128)) # t1010: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1010 = prims.reshape(t1007, (1, 32, 512, 128)) # t1010: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1007\n", + " t1026 = torch_slice_prim_impl(t1009, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1026: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1041 = torch_slice_prim_impl(t1008, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1041: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " t1043 = torch_slice_prim_impl(t1009, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1043: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1009\n", + " t1011 = torch_slice_prim_impl(t1008, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1011: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1008\n", + " t1027 = torch_slice_prim_impl(t1026, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1027: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1028 = torch_slice_prim_impl(t1026, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1028: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1013 = torch_slice_prim_impl(t1011, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1013: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1012 = torch_slice_prim_impl(t1011, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1012: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1016, t1031] = nvFusion41(t1011, t1013, t1026, t1028)\n", + " # t1014 = prims.convert_element_type(t1013, dtypes.float32) # t1014: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1015 = prims.neg(t1014) # t1015: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1016 = prims.convert_element_type(t1015, dtypes.bfloat16) # t1016: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1029 = prims.convert_element_type(t1028, dtypes.float32) # t1029: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1030 = prims.neg(t1029) # t1030: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1031 = prims.convert_element_type(t1030, dtypes.bfloat16) # t1031: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1013, t1028\n", + " t1032 = torch.cat((t1031, t1027), -1) # t1032: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1032 = ltorch.cat((t1031, t1027), -1) # t1032: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1032 = prims.cat((t1031, t1027), -1) # t1032: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1031, t1027\n", + " t1017 = torch.cat((t1016, t1012), -1) # t1017: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1017 = ltorch.cat((t1016, t1012), -1) # t1017: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1017 = prims.cat((t1016, t1012), -1) # t1017: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1016, t1012\n", + " [t1025, t1040] = nvFusion42(t1011, t1017, t1026, t1032, t154, t157)\n", + " # t1019 = prims.convert_element_type(t1011, dtypes.float32) # t1019: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1034 = prims.convert_element_type(t1026, dtypes.float32) # t1034: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1020 = prims.mul(t1019, t154) # t1020: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1022 = prims.convert_element_type(t1017, dtypes.float32) # t1022: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1023 = prims.mul(t1022, t157) # t1023: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1024 = prims.add(t1020, t1023) # t1024: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1025 = prims.convert_element_type(t1024, dtypes.bfloat16) # t1025: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1035 = prims.mul(t1034, t154) # t1035: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1037 = prims.convert_element_type(t1032, dtypes.float32) # t1037: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1038 = prims.mul(t1037, t157) # t1038: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1039 = prims.add(t1035, t1038) # t1039: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1040 = prims.convert_element_type(t1039, dtypes.bfloat16) # t1040: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1011, t1017, t1026, t1032\n", + " t1042 = torch.cat((t1025, t1041), -1) # t1042: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1042 = ltorch.cat((t1025, t1041), -1) # t1042: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1042 = prims.cat((t1025, t1041), -1) # t1042: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1025, t1041\n", + " t1044 = torch.cat((t1040, t1043), -1) # t1044: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1044 = ltorch.cat((t1040, t1043), -1) # t1044: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1044 = prims.cat((t1040, t1043), -1) # t1044: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1040, t1043\n", + " (t1045, t1046, t1047, t1048, _, _, t1049, t1050, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1042, t1044, t1010, 0.0, True, scale=0.08838834764831843)\n", + " t1052 = torch.permute(t1045, (0, 2, 1, 3)) # t1052: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1052 = ltorch.permute(t1045, (0, 2, 1, 3)) # t1052: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1052 = prims.transpose(t1045, (0, 2, 1, 3)) # t1052: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1053 = torch.reshape(t1052, (1, 512, 4096)) # t1053: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1053 = ltorch.reshape(t1052, (1, 512, 4096)) # t1053: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1053 = prims.reshape(t1052, (1, 512, 4096)) # t1053: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1052\n", + " t1054 = torch.nn.functional.linear(t1053, t101, None) # t1054: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1054 = ltorch.linear(t1053, t101, None) # t1054: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1054 = prims.linear(t1053, t101, None) # t1054: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1058, t1065, t1073] = nvFusion43(t1054, t1069, t986)\n", + " # t1056 = prims.convert_element_type(t986, dtypes.float32) # t1056: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1055 = prims.convert_element_type(t1054, dtypes.float32) # t1055: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1057 = prims.add(t1055, t1056) # t1057: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1058 = prims.convert_element_type(t1057, dtypes.bfloat16) # t1058: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1060 = prims.mul(t1057, t1057) # t1060: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1061 = prims.sum(t1060, (2,)) # t1061: \"cuda:0 f32[1, 512]\"\n", + " # t1062 = prims.broadcast_in_dim(t1061, [1, 512, 1], [0, 1]) # t1062: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1063 = prims.div(t1062, 4096.0) # t1063: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1064 = prims.add(t1063, 1e-05) # t1064: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1065 = prims.rsqrt(t1064) # t1065: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1066 = prims.broadcast_in_dim(t1065, (1, 512, 4096), (0, 1, 2)) # t1066: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1067 = prims.mul(t1057, t1066) # t1067: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1071 = prims.convert_element_type(t1069, dtypes.float32) # t1071: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1072 = prims.mul(t1067, t1071) # t1072: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1073 = prims.convert_element_type(t1072, dtypes.bfloat16) # t1073: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1074 = torch.nn.functional.linear(t1073, t27, None) # t1074: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1074 = ltorch.linear(t1073, t27, None) # t1074: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1074 = prims.linear(t1073, t27, None) # t1074: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1075 = torch.nn.functional.linear(t1073, t43, None) # t1075: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1075 = ltorch.linear(t1073, t43, None) # t1075: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1075 = prims.linear(t1073, t43, None) # t1075: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1089] = nvFusion44(t1074, t1075)\n", + " # t1076 = prims.convert_element_type(t1074, dtypes.float32) # t1076: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1077 = prims.neg(t1076) # t1077: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1078 = prims.exp(t1077) # t1078: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1079 = prims.add(1.0, t1078) # t1079: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1080 = prims.reciprocal(t1079) # t1080: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1084 = prims.mul(t1076, t1080) # t1084: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1087 = prims.convert_element_type(t1075, dtypes.float32) # t1087: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1088 = prims.mul(t1084, t1087) # t1088: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1089 = prims.convert_element_type(t1088, dtypes.bfloat16) # t1089: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1090 = torch.nn.functional.linear(t1089, t102, None) # t1090: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1090 = ltorch.linear(t1089, t102, None) # t1090: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1090 = prims.linear(t1089, t102, None) # t1090: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1094, t1101, t1109] = nvFusion45(t1058, t1090, t1105)\n", + " # t1092 = prims.convert_element_type(t1058, dtypes.float32) # t1092: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1091 = prims.convert_element_type(t1090, dtypes.float32) # t1091: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1093 = prims.add(t1091, t1092) # t1093: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1094 = prims.convert_element_type(t1093, dtypes.bfloat16) # t1094: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1096 = prims.mul(t1093, t1093) # t1096: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1097 = prims.sum(t1096, (2,)) # t1097: \"cuda:0 f32[1, 512]\"\n", + " # t1098 = prims.broadcast_in_dim(t1097, [1, 512, 1], [0, 1]) # t1098: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1099 = prims.div(t1098, 4096.0) # t1099: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1100 = prims.add(t1099, 1e-05) # t1100: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1101 = prims.rsqrt(t1100) # t1101: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1102 = prims.broadcast_in_dim(t1101, (1, 512, 4096), (0, 1, 2)) # t1102: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1103 = prims.mul(t1093, t1102) # t1103: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1107 = prims.convert_element_type(t1105, dtypes.float32) # t1107: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1108 = prims.mul(t1103, t1107) # t1108: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1109 = prims.convert_element_type(t1108, dtypes.bfloat16) # t1109: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1110 = torch.nn.functional.linear(t1109, t12, None) # t1110: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1110 = ltorch.linear(t1109, t12, None) # t1110: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1110 = prims.linear(t1109, t12, None) # t1110: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1111 = torch.reshape(t1110, (1, 512, 32, 3, 128)) # t1111: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1111 = ltorch.reshape(t1110, (1, 512, 32, 3, 128)) # t1111: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1111 = prims.reshape(t1110, (1, 512, 32, 3, 128)) # t1111: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1110\n", + " t1112 = torch.permute(t1111, (0, 2, 3, 1, 4)) # t1112: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1112 = ltorch.permute(t1111, (0, 2, 3, 1, 4)) # t1112: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1112 = prims.transpose(t1111, (0, 2, 3, 1, 4)) # t1112: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1111\n", + " (t1113, t1114, t1115) = torch.split(t1112, (1, 1, 1), 2)\n", + " # (t1113, t1114, t1115) = ltorch.split(t1112, (1, 1, 1), 2)\n", + " # t1113 = prims.slice_prim(t1112, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1113: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1114 = prims.slice_prim(t1112, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1114: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1115 = prims.slice_prim(t1112, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1115: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1112\n", + " t1116 = torch.reshape(t1113, (1, 32, 512, 128)) # t1116: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1116 = ltorch.reshape(t1113, (1, 32, 512, 128)) # t1116: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1116 = prims.reshape(t1113, (1, 32, 512, 128)) # t1116: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1113\n", + " t1117 = torch.reshape(t1114, (1, 32, 512, 128)) # t1117: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1117 = ltorch.reshape(t1114, (1, 32, 512, 128)) # t1117: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1117 = prims.reshape(t1114, (1, 32, 512, 128)) # t1117: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1114\n", + " t1118 = torch.reshape(t1115, (1, 32, 512, 128)) # t1118: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1118 = ltorch.reshape(t1115, (1, 32, 512, 128)) # t1118: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1118 = prims.reshape(t1115, (1, 32, 512, 128)) # t1118: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1115\n", + " t1119 = torch_slice_prim_impl(t1116, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1119: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1134 = torch_slice_prim_impl(t1117, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1134: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1149 = torch_slice_prim_impl(t1116, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1149: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1116\n", + " t1151 = torch_slice_prim_impl(t1117, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1151: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1117\n", + " t1120 = torch_slice_prim_impl(t1119, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1120: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1121 = torch_slice_prim_impl(t1119, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1121: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1136 = torch_slice_prim_impl(t1134, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1136: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1135 = torch_slice_prim_impl(t1134, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1135: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1124, t1139] = nvFusion46(t1119, t1121, t1134, t1136)\n", + " # t1122 = prims.convert_element_type(t1121, dtypes.float32) # t1122: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1123 = prims.neg(t1122) # t1123: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1124 = prims.convert_element_type(t1123, dtypes.bfloat16) # t1124: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1137 = prims.convert_element_type(t1136, dtypes.float32) # t1137: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1138 = prims.neg(t1137) # t1138: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1139 = prims.convert_element_type(t1138, dtypes.bfloat16) # t1139: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1121, t1136\n", + " t1125 = torch.cat((t1124, t1120), -1) # t1125: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1125 = ltorch.cat((t1124, t1120), -1) # t1125: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1125 = prims.cat((t1124, t1120), -1) # t1125: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1124, t1120\n", + " t1140 = torch.cat((t1139, t1135), -1) # t1140: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1140 = ltorch.cat((t1139, t1135), -1) # t1140: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1140 = prims.cat((t1139, t1135), -1) # t1140: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1139, t1135\n", + " [t1133, t1148] = nvFusion47(t1119, t1125, t1134, t1140, t154, t157)\n", + " # t1127 = prims.convert_element_type(t1119, dtypes.float32) # t1127: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1142 = prims.convert_element_type(t1134, dtypes.float32) # t1142: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1128 = prims.mul(t1127, t154) # t1128: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1130 = prims.convert_element_type(t1125, dtypes.float32) # t1130: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1131 = prims.mul(t1130, t157) # t1131: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1132 = prims.add(t1128, t1131) # t1132: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1133 = prims.convert_element_type(t1132, dtypes.bfloat16) # t1133: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1143 = prims.mul(t1142, t154) # t1143: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1145 = prims.convert_element_type(t1140, dtypes.float32) # t1145: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1146 = prims.mul(t1145, t157) # t1146: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1147 = prims.add(t1143, t1146) # t1147: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1148 = prims.convert_element_type(t1147, dtypes.bfloat16) # t1148: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1119, t1125, t1134, t1140\n", + " t1152 = torch.cat((t1148, t1151), -1) # t1152: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1152 = ltorch.cat((t1148, t1151), -1) # t1152: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1152 = prims.cat((t1148, t1151), -1) # t1152: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1148, t1151\n", + " t1150 = torch.cat((t1133, t1149), -1) # t1150: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1150 = ltorch.cat((t1133, t1149), -1) # t1150: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1150 = prims.cat((t1133, t1149), -1) # t1150: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1133, t1149\n", + " (t1153, t1154, t1155, t1156, _, _, t1157, t1158, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1150, t1152, t1118, 0.0, True, scale=0.08838834764831843)\n", + " t1160 = torch.permute(t1153, (0, 2, 1, 3)) # t1160: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1160 = ltorch.permute(t1153, (0, 2, 1, 3)) # t1160: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1160 = prims.transpose(t1153, (0, 2, 1, 3)) # t1160: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1161 = torch.reshape(t1160, (1, 512, 4096)) # t1161: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1161 = ltorch.reshape(t1160, (1, 512, 4096)) # t1161: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1161 = prims.reshape(t1160, (1, 512, 4096)) # t1161: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1160\n", + " t1162 = torch.nn.functional.linear(t1161, t103, None) # t1162: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1162 = ltorch.linear(t1161, t103, None) # t1162: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1162 = prims.linear(t1161, t103, None) # t1162: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1166, t1173, t1181] = nvFusion48(t1094, t1162, t1177)\n", + " # t1164 = prims.convert_element_type(t1094, dtypes.float32) # t1164: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1163 = prims.convert_element_type(t1162, dtypes.float32) # t1163: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1165 = prims.add(t1163, t1164) # t1165: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1166 = prims.convert_element_type(t1165, dtypes.bfloat16) # t1166: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1168 = prims.mul(t1165, t1165) # t1168: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1169 = prims.sum(t1168, (2,)) # t1169: \"cuda:0 f32[1, 512]\"\n", + " # t1170 = prims.broadcast_in_dim(t1169, [1, 512, 1], [0, 1]) # t1170: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1171 = prims.div(t1170, 4096.0) # t1171: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1172 = prims.add(t1171, 1e-05) # t1172: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1173 = prims.rsqrt(t1172) # t1173: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1174 = prims.broadcast_in_dim(t1173, (1, 512, 4096), (0, 1, 2)) # t1174: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1175 = prims.mul(t1165, t1174) # t1175: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1179 = prims.convert_element_type(t1177, dtypes.float32) # t1179: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1180 = prims.mul(t1175, t1179) # t1180: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1181 = prims.convert_element_type(t1180, dtypes.bfloat16) # t1181: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1182 = torch.nn.functional.linear(t1181, t28, None) # t1182: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1182 = ltorch.linear(t1181, t28, None) # t1182: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1182 = prims.linear(t1181, t28, None) # t1182: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1183 = torch.nn.functional.linear(t1181, t44, None) # t1183: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1183 = ltorch.linear(t1181, t44, None) # t1183: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1183 = prims.linear(t1181, t44, None) # t1183: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1197] = nvFusion49(t1182, t1183)\n", + " # t1184 = prims.convert_element_type(t1182, dtypes.float32) # t1184: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1185 = prims.neg(t1184) # t1185: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1186 = prims.exp(t1185) # t1186: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1187 = prims.add(1.0, t1186) # t1187: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1188 = prims.reciprocal(t1187) # t1188: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1192 = prims.mul(t1184, t1188) # t1192: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1195 = prims.convert_element_type(t1183, dtypes.float32) # t1195: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1196 = prims.mul(t1192, t1195) # t1196: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1197 = prims.convert_element_type(t1196, dtypes.bfloat16) # t1197: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1198 = torch.nn.functional.linear(t1197, t104, None) # t1198: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1198 = ltorch.linear(t1197, t104, None) # t1198: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1198 = prims.linear(t1197, t104, None) # t1198: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1202, t1209, t1217] = nvFusion50(t1166, t1198, t1213)\n", + " # t1200 = prims.convert_element_type(t1166, dtypes.float32) # t1200: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1199 = prims.convert_element_type(t1198, dtypes.float32) # t1199: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1201 = prims.add(t1199, t1200) # t1201: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1202 = prims.convert_element_type(t1201, dtypes.bfloat16) # t1202: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1204 = prims.mul(t1201, t1201) # t1204: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1205 = prims.sum(t1204, (2,)) # t1205: \"cuda:0 f32[1, 512]\"\n", + " # t1206 = prims.broadcast_in_dim(t1205, [1, 512, 1], [0, 1]) # t1206: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1207 = prims.div(t1206, 4096.0) # t1207: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1208 = prims.add(t1207, 1e-05) # t1208: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1209 = prims.rsqrt(t1208) # t1209: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1210 = prims.broadcast_in_dim(t1209, (1, 512, 4096), (0, 1, 2)) # t1210: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1211 = prims.mul(t1201, t1210) # t1211: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1215 = prims.convert_element_type(t1213, dtypes.float32) # t1215: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1216 = prims.mul(t1211, t1215) # t1216: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1217 = prims.convert_element_type(t1216, dtypes.bfloat16) # t1217: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1218 = torch.nn.functional.linear(t1217, t13, None) # t1218: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1218 = ltorch.linear(t1217, t13, None) # t1218: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1218 = prims.linear(t1217, t13, None) # t1218: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1219 = torch.reshape(t1218, (1, 512, 32, 3, 128)) # t1219: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1219 = ltorch.reshape(t1218, (1, 512, 32, 3, 128)) # t1219: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1219 = prims.reshape(t1218, (1, 512, 32, 3, 128)) # t1219: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1218\n", + " t1220 = torch.permute(t1219, (0, 2, 3, 1, 4)) # t1220: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1220 = ltorch.permute(t1219, (0, 2, 3, 1, 4)) # t1220: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1220 = prims.transpose(t1219, (0, 2, 3, 1, 4)) # t1220: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1219\n", + " (t1221, t1222, t1223) = torch.split(t1220, (1, 1, 1), 2)\n", + " # (t1221, t1222, t1223) = ltorch.split(t1220, (1, 1, 1), 2)\n", + " # t1221 = prims.slice_prim(t1220, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1221: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1222 = prims.slice_prim(t1220, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1222: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1223 = prims.slice_prim(t1220, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1223: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1220\n", + " t1224 = torch.reshape(t1221, (1, 32, 512, 128)) # t1224: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1224 = ltorch.reshape(t1221, (1, 32, 512, 128)) # t1224: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1224 = prims.reshape(t1221, (1, 32, 512, 128)) # t1224: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1221\n", + " t1225 = torch.reshape(t1222, (1, 32, 512, 128)) # t1225: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1225 = ltorch.reshape(t1222, (1, 32, 512, 128)) # t1225: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1225 = prims.reshape(t1222, (1, 32, 512, 128)) # t1225: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1222\n", + " t1226 = torch.reshape(t1223, (1, 32, 512, 128)) # t1226: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1226 = ltorch.reshape(t1223, (1, 32, 512, 128)) # t1226: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1226 = prims.reshape(t1223, (1, 32, 512, 128)) # t1226: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1223\n", + " t1227 = torch_slice_prim_impl(t1224, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1227: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1242 = torch_slice_prim_impl(t1225, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1242: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1257 = torch_slice_prim_impl(t1224, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1257: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1224\n", + " t1259 = torch_slice_prim_impl(t1225, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1259: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1225\n", + " t1228 = torch_slice_prim_impl(t1227, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1228: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1229 = torch_slice_prim_impl(t1227, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1229: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1243 = torch_slice_prim_impl(t1242, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1243: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1244 = torch_slice_prim_impl(t1242, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1244: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1232, t1247] = nvFusion51(t1227, t1229, t1242, t1244)\n", + " # t1230 = prims.convert_element_type(t1229, dtypes.float32) # t1230: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1231 = prims.neg(t1230) # t1231: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1232 = prims.convert_element_type(t1231, dtypes.bfloat16) # t1232: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1245 = prims.convert_element_type(t1244, dtypes.float32) # t1245: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1246 = prims.neg(t1245) # t1246: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1247 = prims.convert_element_type(t1246, dtypes.bfloat16) # t1247: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1229, t1244\n", + " t1233 = torch.cat((t1232, t1228), -1) # t1233: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1233 = ltorch.cat((t1232, t1228), -1) # t1233: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1233 = prims.cat((t1232, t1228), -1) # t1233: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1232, t1228\n", + " t1248 = torch.cat((t1247, t1243), -1) # t1248: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1248 = ltorch.cat((t1247, t1243), -1) # t1248: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1248 = prims.cat((t1247, t1243), -1) # t1248: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1247, t1243\n", + " [t1241, t1256] = nvFusion52(t1227, t1233, t1242, t1248, t154, t157)\n", + " # t1235 = prims.convert_element_type(t1227, dtypes.float32) # t1235: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1250 = prims.convert_element_type(t1242, dtypes.float32) # t1250: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1236 = prims.mul(t1235, t154) # t1236: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1238 = prims.convert_element_type(t1233, dtypes.float32) # t1238: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1239 = prims.mul(t1238, t157) # t1239: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1240 = prims.add(t1236, t1239) # t1240: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1241 = prims.convert_element_type(t1240, dtypes.bfloat16) # t1241: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1251 = prims.mul(t1250, t154) # t1251: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1253 = prims.convert_element_type(t1248, dtypes.float32) # t1253: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1254 = prims.mul(t1253, t157) # t1254: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1255 = prims.add(t1251, t1254) # t1255: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1256 = prims.convert_element_type(t1255, dtypes.bfloat16) # t1256: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1227, t1233, t1242, t1248\n", + " t1258 = torch.cat((t1241, t1257), -1) # t1258: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1258 = ltorch.cat((t1241, t1257), -1) # t1258: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1258 = prims.cat((t1241, t1257), -1) # t1258: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1241, t1257\n", + " t1260 = torch.cat((t1256, t1259), -1) # t1260: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1260 = ltorch.cat((t1256, t1259), -1) # t1260: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1260 = prims.cat((t1256, t1259), -1) # t1260: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1256, t1259\n", + " (t1261, t1262, t1263, t1264, _, _, t1265, t1266, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1258, t1260, t1226, 0.0, True, scale=0.08838834764831843)\n", + " t1268 = torch.permute(t1261, (0, 2, 1, 3)) # t1268: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1268 = ltorch.permute(t1261, (0, 2, 1, 3)) # t1268: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1268 = prims.transpose(t1261, (0, 2, 1, 3)) # t1268: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1269 = torch.reshape(t1268, (1, 512, 4096)) # t1269: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1269 = ltorch.reshape(t1268, (1, 512, 4096)) # t1269: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1269 = prims.reshape(t1268, (1, 512, 4096)) # t1269: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1268\n", + " t1270 = torch.nn.functional.linear(t1269, t105, None) # t1270: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1270 = ltorch.linear(t1269, t105, None) # t1270: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1270 = prims.linear(t1269, t105, None) # t1270: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1274, t1281, t1289] = nvFusion53(t1202, t1270, t1285)\n", + " # t1272 = prims.convert_element_type(t1202, dtypes.float32) # t1272: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1271 = prims.convert_element_type(t1270, dtypes.float32) # t1271: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1273 = prims.add(t1271, t1272) # t1273: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1274 = prims.convert_element_type(t1273, dtypes.bfloat16) # t1274: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1276 = prims.mul(t1273, t1273) # t1276: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1277 = prims.sum(t1276, (2,)) # t1277: \"cuda:0 f32[1, 512]\"\n", + " # t1278 = prims.broadcast_in_dim(t1277, [1, 512, 1], [0, 1]) # t1278: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1279 = prims.div(t1278, 4096.0) # t1279: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1280 = prims.add(t1279, 1e-05) # t1280: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1281 = prims.rsqrt(t1280) # t1281: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1282 = prims.broadcast_in_dim(t1281, (1, 512, 4096), (0, 1, 2)) # t1282: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1283 = prims.mul(t1273, t1282) # t1283: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1287 = prims.convert_element_type(t1285, dtypes.float32) # t1287: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1288 = prims.mul(t1283, t1287) # t1288: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1289 = prims.convert_element_type(t1288, dtypes.bfloat16) # t1289: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1290 = torch.nn.functional.linear(t1289, t29, None) # t1290: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1290 = ltorch.linear(t1289, t29, None) # t1290: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1290 = prims.linear(t1289, t29, None) # t1290: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1291 = torch.nn.functional.linear(t1289, t45, None) # t1291: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1291 = ltorch.linear(t1289, t45, None) # t1291: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1291 = prims.linear(t1289, t45, None) # t1291: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1305] = nvFusion54(t1290, t1291)\n", + " # t1292 = prims.convert_element_type(t1290, dtypes.float32) # t1292: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1293 = prims.neg(t1292) # t1293: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1294 = prims.exp(t1293) # t1294: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1295 = prims.add(1.0, t1294) # t1295: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1296 = prims.reciprocal(t1295) # t1296: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1300 = prims.mul(t1292, t1296) # t1300: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1303 = prims.convert_element_type(t1291, dtypes.float32) # t1303: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1304 = prims.mul(t1300, t1303) # t1304: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1305 = prims.convert_element_type(t1304, dtypes.bfloat16) # t1305: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1306 = torch.nn.functional.linear(t1305, t106, None) # t1306: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1306 = ltorch.linear(t1305, t106, None) # t1306: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1306 = prims.linear(t1305, t106, None) # t1306: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1310, t1317, t1325] = nvFusion55(t1274, t1306, t1321)\n", + " # t1308 = prims.convert_element_type(t1274, dtypes.float32) # t1308: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1307 = prims.convert_element_type(t1306, dtypes.float32) # t1307: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1309 = prims.add(t1307, t1308) # t1309: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1310 = prims.convert_element_type(t1309, dtypes.bfloat16) # t1310: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1312 = prims.mul(t1309, t1309) # t1312: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1313 = prims.sum(t1312, (2,)) # t1313: \"cuda:0 f32[1, 512]\"\n", + " # t1314 = prims.broadcast_in_dim(t1313, [1, 512, 1], [0, 1]) # t1314: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1315 = prims.div(t1314, 4096.0) # t1315: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1316 = prims.add(t1315, 1e-05) # t1316: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1317 = prims.rsqrt(t1316) # t1317: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1318 = prims.broadcast_in_dim(t1317, (1, 512, 4096), (0, 1, 2)) # t1318: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1319 = prims.mul(t1309, t1318) # t1319: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1323 = prims.convert_element_type(t1321, dtypes.float32) # t1323: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1324 = prims.mul(t1319, t1323) # t1324: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1325 = prims.convert_element_type(t1324, dtypes.bfloat16) # t1325: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1326 = torch.nn.functional.linear(t1325, t14, None) # t1326: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1326 = ltorch.linear(t1325, t14, None) # t1326: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1326 = prims.linear(t1325, t14, None) # t1326: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1327 = torch.reshape(t1326, (1, 512, 32, 3, 128)) # t1327: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1327 = ltorch.reshape(t1326, (1, 512, 32, 3, 128)) # t1327: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1327 = prims.reshape(t1326, (1, 512, 32, 3, 128)) # t1327: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1326\n", + " t1328 = torch.permute(t1327, (0, 2, 3, 1, 4)) # t1328: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1328 = ltorch.permute(t1327, (0, 2, 3, 1, 4)) # t1328: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1328 = prims.transpose(t1327, (0, 2, 3, 1, 4)) # t1328: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1327\n", + " (t1329, t1330, t1331) = torch.split(t1328, (1, 1, 1), 2)\n", + " # (t1329, t1330, t1331) = ltorch.split(t1328, (1, 1, 1), 2)\n", + " # t1329 = prims.slice_prim(t1328, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1329: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1330 = prims.slice_prim(t1328, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1330: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1331 = prims.slice_prim(t1328, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1331: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1328\n", + " t1332 = torch.reshape(t1329, (1, 32, 512, 128)) # t1332: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1332 = ltorch.reshape(t1329, (1, 32, 512, 128)) # t1332: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1332 = prims.reshape(t1329, (1, 32, 512, 128)) # t1332: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1329\n", + " t1333 = torch.reshape(t1330, (1, 32, 512, 128)) # t1333: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1333 = ltorch.reshape(t1330, (1, 32, 512, 128)) # t1333: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1333 = prims.reshape(t1330, (1, 32, 512, 128)) # t1333: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1330\n", + " t1334 = torch.reshape(t1331, (1, 32, 512, 128)) # t1334: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1334 = ltorch.reshape(t1331, (1, 32, 512, 128)) # t1334: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1334 = prims.reshape(t1331, (1, 32, 512, 128)) # t1334: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1331\n", + " t1335 = torch_slice_prim_impl(t1332, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1335: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1350 = torch_slice_prim_impl(t1333, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1350: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1365 = torch_slice_prim_impl(t1332, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1365: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1332\n", + " t1367 = torch_slice_prim_impl(t1333, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1367: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1333\n", + " t1336 = torch_slice_prim_impl(t1335, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1336: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1337 = torch_slice_prim_impl(t1335, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1337: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1351 = torch_slice_prim_impl(t1350, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1351: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1352 = torch_slice_prim_impl(t1350, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1352: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1340, t1355] = nvFusion56(t1335, t1337, t1350, t1352)\n", + " # t1338 = prims.convert_element_type(t1337, dtypes.float32) # t1338: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1339 = prims.neg(t1338) # t1339: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1340 = prims.convert_element_type(t1339, dtypes.bfloat16) # t1340: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1353 = prims.convert_element_type(t1352, dtypes.float32) # t1353: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1354 = prims.neg(t1353) # t1354: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1355 = prims.convert_element_type(t1354, dtypes.bfloat16) # t1355: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1337, t1352\n", + " t1341 = torch.cat((t1340, t1336), -1) # t1341: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1341 = ltorch.cat((t1340, t1336), -1) # t1341: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1341 = prims.cat((t1340, t1336), -1) # t1341: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1340, t1336\n", + " t1356 = torch.cat((t1355, t1351), -1) # t1356: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1356 = ltorch.cat((t1355, t1351), -1) # t1356: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1356 = prims.cat((t1355, t1351), -1) # t1356: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1355, t1351\n", + " [t1349, t1364] = nvFusion57(t1335, t1341, t1350, t1356, t154, t157)\n", + " # t1343 = prims.convert_element_type(t1335, dtypes.float32) # t1343: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1358 = prims.convert_element_type(t1350, dtypes.float32) # t1358: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1344 = prims.mul(t1343, t154) # t1344: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1346 = prims.convert_element_type(t1341, dtypes.float32) # t1346: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1347 = prims.mul(t1346, t157) # t1347: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1348 = prims.add(t1344, t1347) # t1348: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1349 = prims.convert_element_type(t1348, dtypes.bfloat16) # t1349: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1359 = prims.mul(t1358, t154) # t1359: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1361 = prims.convert_element_type(t1356, dtypes.float32) # t1361: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1362 = prims.mul(t1361, t157) # t1362: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1363 = prims.add(t1359, t1362) # t1363: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1364 = prims.convert_element_type(t1363, dtypes.bfloat16) # t1364: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1335, t1341, t1350, t1356\n", + " t1366 = torch.cat((t1349, t1365), -1) # t1366: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1366 = ltorch.cat((t1349, t1365), -1) # t1366: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1366 = prims.cat((t1349, t1365), -1) # t1366: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1349, t1365\n", + " t1368 = torch.cat((t1364, t1367), -1) # t1368: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1368 = ltorch.cat((t1364, t1367), -1) # t1368: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1368 = prims.cat((t1364, t1367), -1) # t1368: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1364, t1367\n", + " (t1369, t1370, t1371, t1372, _, _, t1373, t1374, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1366, t1368, t1334, 0.0, True, scale=0.08838834764831843)\n", + " t1376 = torch.permute(t1369, (0, 2, 1, 3)) # t1376: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1376 = ltorch.permute(t1369, (0, 2, 1, 3)) # t1376: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1376 = prims.transpose(t1369, (0, 2, 1, 3)) # t1376: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1377 = torch.reshape(t1376, (1, 512, 4096)) # t1377: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1377 = ltorch.reshape(t1376, (1, 512, 4096)) # t1377: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1377 = prims.reshape(t1376, (1, 512, 4096)) # t1377: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1376\n", + " t1378 = torch.nn.functional.linear(t1377, t107, None) # t1378: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1378 = ltorch.linear(t1377, t107, None) # t1378: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1378 = prims.linear(t1377, t107, None) # t1378: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1382, t1389, t1397] = nvFusion58(t1310, t1378, t1393)\n", + " # t1380 = prims.convert_element_type(t1310, dtypes.float32) # t1380: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1379 = prims.convert_element_type(t1378, dtypes.float32) # t1379: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1381 = prims.add(t1379, t1380) # t1381: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1382 = prims.convert_element_type(t1381, dtypes.bfloat16) # t1382: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1384 = prims.mul(t1381, t1381) # t1384: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1385 = prims.sum(t1384, (2,)) # t1385: \"cuda:0 f32[1, 512]\"\n", + " # t1386 = prims.broadcast_in_dim(t1385, [1, 512, 1], [0, 1]) # t1386: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1387 = prims.div(t1386, 4096.0) # t1387: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1388 = prims.add(t1387, 1e-05) # t1388: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1389 = prims.rsqrt(t1388) # t1389: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1390 = prims.broadcast_in_dim(t1389, (1, 512, 4096), (0, 1, 2)) # t1390: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1391 = prims.mul(t1381, t1390) # t1391: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1395 = prims.convert_element_type(t1393, dtypes.float32) # t1395: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1396 = prims.mul(t1391, t1395) # t1396: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1397 = prims.convert_element_type(t1396, dtypes.bfloat16) # t1397: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1398 = torch.nn.functional.linear(t1397, t30, None) # t1398: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1398 = ltorch.linear(t1397, t30, None) # t1398: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1398 = prims.linear(t1397, t30, None) # t1398: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1399 = torch.nn.functional.linear(t1397, t46, None) # t1399: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1399 = ltorch.linear(t1397, t46, None) # t1399: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1399 = prims.linear(t1397, t46, None) # t1399: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1413] = nvFusion59(t1398, t1399)\n", + " # t1400 = prims.convert_element_type(t1398, dtypes.float32) # t1400: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1401 = prims.neg(t1400) # t1401: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1402 = prims.exp(t1401) # t1402: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1403 = prims.add(1.0, t1402) # t1403: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1404 = prims.reciprocal(t1403) # t1404: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1408 = prims.mul(t1400, t1404) # t1408: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1411 = prims.convert_element_type(t1399, dtypes.float32) # t1411: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1412 = prims.mul(t1408, t1411) # t1412: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1413 = prims.convert_element_type(t1412, dtypes.bfloat16) # t1413: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1414 = torch.nn.functional.linear(t1413, t108, None) # t1414: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1414 = ltorch.linear(t1413, t108, None) # t1414: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1414 = prims.linear(t1413, t108, None) # t1414: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1418, t1425, t1433] = nvFusion60(t1382, t1414, t1429)\n", + " # t1416 = prims.convert_element_type(t1382, dtypes.float32) # t1416: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1415 = prims.convert_element_type(t1414, dtypes.float32) # t1415: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1417 = prims.add(t1415, t1416) # t1417: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1418 = prims.convert_element_type(t1417, dtypes.bfloat16) # t1418: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1420 = prims.mul(t1417, t1417) # t1420: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1421 = prims.sum(t1420, (2,)) # t1421: \"cuda:0 f32[1, 512]\"\n", + " # t1422 = prims.broadcast_in_dim(t1421, [1, 512, 1], [0, 1]) # t1422: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1423 = prims.div(t1422, 4096.0) # t1423: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1424 = prims.add(t1423, 1e-05) # t1424: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1425 = prims.rsqrt(t1424) # t1425: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1426 = prims.broadcast_in_dim(t1425, (1, 512, 4096), (0, 1, 2)) # t1426: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1427 = prims.mul(t1417, t1426) # t1427: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1431 = prims.convert_element_type(t1429, dtypes.float32) # t1431: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1432 = prims.mul(t1427, t1431) # t1432: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1433 = prims.convert_element_type(t1432, dtypes.bfloat16) # t1433: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1434 = torch.nn.functional.linear(t1433, t15, None) # t1434: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1434 = ltorch.linear(t1433, t15, None) # t1434: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1434 = prims.linear(t1433, t15, None) # t1434: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1435 = torch.reshape(t1434, (1, 512, 32, 3, 128)) # t1435: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1435 = ltorch.reshape(t1434, (1, 512, 32, 3, 128)) # t1435: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1435 = prims.reshape(t1434, (1, 512, 32, 3, 128)) # t1435: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1434\n", + " t1436 = torch.permute(t1435, (0, 2, 3, 1, 4)) # t1436: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1436 = ltorch.permute(t1435, (0, 2, 3, 1, 4)) # t1436: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1436 = prims.transpose(t1435, (0, 2, 3, 1, 4)) # t1436: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1435\n", + " (t1437, t1438, t1439) = torch.split(t1436, (1, 1, 1), 2)\n", + " # (t1437, t1438, t1439) = ltorch.split(t1436, (1, 1, 1), 2)\n", + " # t1437 = prims.slice_prim(t1436, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1437: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1438 = prims.slice_prim(t1436, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1438: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1439 = prims.slice_prim(t1436, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1439: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1436\n", + " t1440 = torch.reshape(t1437, (1, 32, 512, 128)) # t1440: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1440 = ltorch.reshape(t1437, (1, 32, 512, 128)) # t1440: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1440 = prims.reshape(t1437, (1, 32, 512, 128)) # t1440: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1437\n", + " t1441 = torch.reshape(t1438, (1, 32, 512, 128)) # t1441: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1441 = ltorch.reshape(t1438, (1, 32, 512, 128)) # t1441: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1441 = prims.reshape(t1438, (1, 32, 512, 128)) # t1441: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1438\n", + " t1442 = torch.reshape(t1439, (1, 32, 512, 128)) # t1442: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1442 = ltorch.reshape(t1439, (1, 32, 512, 128)) # t1442: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1442 = prims.reshape(t1439, (1, 32, 512, 128)) # t1442: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1439\n", + " t1443 = torch_slice_prim_impl(t1440, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1443: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1458 = torch_slice_prim_impl(t1441, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1458: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1473 = torch_slice_prim_impl(t1440, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1473: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1440\n", + " t1475 = torch_slice_prim_impl(t1441, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1475: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1441\n", + " t1444 = torch_slice_prim_impl(t1443, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1444: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1445 = torch_slice_prim_impl(t1443, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1445: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1459 = torch_slice_prim_impl(t1458, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1459: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1460 = torch_slice_prim_impl(t1458, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1460: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1448, t1463] = nvFusion61(t1443, t1445, t1458, t1460)\n", + " # t1446 = prims.convert_element_type(t1445, dtypes.float32) # t1446: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1447 = prims.neg(t1446) # t1447: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1448 = prims.convert_element_type(t1447, dtypes.bfloat16) # t1448: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1461 = prims.convert_element_type(t1460, dtypes.float32) # t1461: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1462 = prims.neg(t1461) # t1462: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1463 = prims.convert_element_type(t1462, dtypes.bfloat16) # t1463: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1445, t1460\n", + " t1464 = torch.cat((t1463, t1459), -1) # t1464: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1464 = ltorch.cat((t1463, t1459), -1) # t1464: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1464 = prims.cat((t1463, t1459), -1) # t1464: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1463, t1459\n", + " t1449 = torch.cat((t1448, t1444), -1) # t1449: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1449 = ltorch.cat((t1448, t1444), -1) # t1449: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1449 = prims.cat((t1448, t1444), -1) # t1449: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1448, t1444\n", + " [t1457, t1472] = nvFusion62(t1443, t1449, t1458, t1464, t154, t157)\n", + " # t1451 = prims.convert_element_type(t1443, dtypes.float32) # t1451: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1466 = prims.convert_element_type(t1458, dtypes.float32) # t1466: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1467 = prims.mul(t1466, t154) # t1467: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1469 = prims.convert_element_type(t1464, dtypes.float32) # t1469: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1470 = prims.mul(t1469, t157) # t1470: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1471 = prims.add(t1467, t1470) # t1471: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1472 = prims.convert_element_type(t1471, dtypes.bfloat16) # t1472: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1452 = prims.mul(t1451, t154) # t1452: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1454 = prims.convert_element_type(t1449, dtypes.float32) # t1454: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1455 = prims.mul(t1454, t157) # t1455: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1456 = prims.add(t1452, t1455) # t1456: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1457 = prims.convert_element_type(t1456, dtypes.bfloat16) # t1457: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1443, t1449, t1458, t1464\n", + " t1476 = torch.cat((t1472, t1475), -1) # t1476: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1476 = ltorch.cat((t1472, t1475), -1) # t1476: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1476 = prims.cat((t1472, t1475), -1) # t1476: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1472, t1475\n", + " t1474 = torch.cat((t1457, t1473), -1) # t1474: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1474 = ltorch.cat((t1457, t1473), -1) # t1474: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1474 = prims.cat((t1457, t1473), -1) # t1474: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1457, t1473\n", + " (t1477, t1478, t1479, t1480, _, _, t1481, t1482, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1474, t1476, t1442, 0.0, True, scale=0.08838834764831843)\n", + " t1484 = torch.permute(t1477, (0, 2, 1, 3)) # t1484: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1484 = ltorch.permute(t1477, (0, 2, 1, 3)) # t1484: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1484 = prims.transpose(t1477, (0, 2, 1, 3)) # t1484: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1485 = torch.reshape(t1484, (1, 512, 4096)) # t1485: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1485 = ltorch.reshape(t1484, (1, 512, 4096)) # t1485: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1485 = prims.reshape(t1484, (1, 512, 4096)) # t1485: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1484\n", + " t1486 = torch.nn.functional.linear(t1485, t109, None) # t1486: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1486 = ltorch.linear(t1485, t109, None) # t1486: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1486 = prims.linear(t1485, t109, None) # t1486: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1490, t1497, t1505] = nvFusion63(t1418, t1486, t1501)\n", + " # t1488 = prims.convert_element_type(t1418, dtypes.float32) # t1488: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1487 = prims.convert_element_type(t1486, dtypes.float32) # t1487: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1489 = prims.add(t1487, t1488) # t1489: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1490 = prims.convert_element_type(t1489, dtypes.bfloat16) # t1490: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1492 = prims.mul(t1489, t1489) # t1492: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1493 = prims.sum(t1492, (2,)) # t1493: \"cuda:0 f32[1, 512]\"\n", + " # t1494 = prims.broadcast_in_dim(t1493, [1, 512, 1], [0, 1]) # t1494: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1495 = prims.div(t1494, 4096.0) # t1495: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1496 = prims.add(t1495, 1e-05) # t1496: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1497 = prims.rsqrt(t1496) # t1497: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1498 = prims.broadcast_in_dim(t1497, (1, 512, 4096), (0, 1, 2)) # t1498: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1499 = prims.mul(t1489, t1498) # t1499: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1503 = prims.convert_element_type(t1501, dtypes.float32) # t1503: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1504 = prims.mul(t1499, t1503) # t1504: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1505 = prims.convert_element_type(t1504, dtypes.bfloat16) # t1505: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1506 = torch.nn.functional.linear(t1505, t31, None) # t1506: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1506 = ltorch.linear(t1505, t31, None) # t1506: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1506 = prims.linear(t1505, t31, None) # t1506: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1507 = torch.nn.functional.linear(t1505, t47, None) # t1507: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1507 = ltorch.linear(t1505, t47, None) # t1507: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1507 = prims.linear(t1505, t47, None) # t1507: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1521] = nvFusion64(t1506, t1507)\n", + " # t1508 = prims.convert_element_type(t1506, dtypes.float32) # t1508: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1509 = prims.neg(t1508) # t1509: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1510 = prims.exp(t1509) # t1510: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1511 = prims.add(1.0, t1510) # t1511: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1512 = prims.reciprocal(t1511) # t1512: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1516 = prims.mul(t1508, t1512) # t1516: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1519 = prims.convert_element_type(t1507, dtypes.float32) # t1519: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1520 = prims.mul(t1516, t1519) # t1520: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1521 = prims.convert_element_type(t1520, dtypes.bfloat16) # t1521: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1522 = torch.nn.functional.linear(t1521, t110, None) # t1522: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1522 = ltorch.linear(t1521, t110, None) # t1522: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1522 = prims.linear(t1521, t110, None) # t1522: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1526, t1533, t1541] = nvFusion65(t1490, t1522, t1537)\n", + " # t1524 = prims.convert_element_type(t1490, dtypes.float32) # t1524: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1523 = prims.convert_element_type(t1522, dtypes.float32) # t1523: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1525 = prims.add(t1523, t1524) # t1525: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1526 = prims.convert_element_type(t1525, dtypes.bfloat16) # t1526: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1528 = prims.mul(t1525, t1525) # t1528: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1529 = prims.sum(t1528, (2,)) # t1529: \"cuda:0 f32[1, 512]\"\n", + " # t1530 = prims.broadcast_in_dim(t1529, [1, 512, 1], [0, 1]) # t1530: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1531 = prims.div(t1530, 4096.0) # t1531: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1532 = prims.add(t1531, 1e-05) # t1532: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1533 = prims.rsqrt(t1532) # t1533: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1534 = prims.broadcast_in_dim(t1533, (1, 512, 4096), (0, 1, 2)) # t1534: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1535 = prims.mul(t1525, t1534) # t1535: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1539 = prims.convert_element_type(t1537, dtypes.float32) # t1539: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1540 = prims.mul(t1535, t1539) # t1540: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1541 = prims.convert_element_type(t1540, dtypes.bfloat16) # t1541: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1542 = torch.nn.functional.linear(t1541, t16, None) # t1542: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1542 = ltorch.linear(t1541, t16, None) # t1542: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1542 = prims.linear(t1541, t16, None) # t1542: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1543 = torch.reshape(t1542, (1, 512, 32, 3, 128)) # t1543: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1543 = ltorch.reshape(t1542, (1, 512, 32, 3, 128)) # t1543: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1543 = prims.reshape(t1542, (1, 512, 32, 3, 128)) # t1543: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1542\n", + " t1544 = torch.permute(t1543, (0, 2, 3, 1, 4)) # t1544: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1544 = ltorch.permute(t1543, (0, 2, 3, 1, 4)) # t1544: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1544 = prims.transpose(t1543, (0, 2, 3, 1, 4)) # t1544: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1543\n", + " (t1545, t1546, t1547) = torch.split(t1544, (1, 1, 1), 2)\n", + " # (t1545, t1546, t1547) = ltorch.split(t1544, (1, 1, 1), 2)\n", + " # t1545 = prims.slice_prim(t1544, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1545: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1546 = prims.slice_prim(t1544, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1546: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1547 = prims.slice_prim(t1544, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1547: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1544\n", + " t1548 = torch.reshape(t1545, (1, 32, 512, 128)) # t1548: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1548 = ltorch.reshape(t1545, (1, 32, 512, 128)) # t1548: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1548 = prims.reshape(t1545, (1, 32, 512, 128)) # t1548: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1545\n", + " t1549 = torch.reshape(t1546, (1, 32, 512, 128)) # t1549: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1549 = ltorch.reshape(t1546, (1, 32, 512, 128)) # t1549: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1549 = prims.reshape(t1546, (1, 32, 512, 128)) # t1549: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1546\n", + " t1550 = torch.reshape(t1547, (1, 32, 512, 128)) # t1550: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1550 = ltorch.reshape(t1547, (1, 32, 512, 128)) # t1550: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1550 = prims.reshape(t1547, (1, 32, 512, 128)) # t1550: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1547\n", + " t1551 = torch_slice_prim_impl(t1548, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1551: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1566 = torch_slice_prim_impl(t1549, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1566: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1581 = torch_slice_prim_impl(t1548, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1581: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1548\n", + " t1583 = torch_slice_prim_impl(t1549, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1583: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1549\n", + " t1552 = torch_slice_prim_impl(t1551, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1552: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1553 = torch_slice_prim_impl(t1551, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1553: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1567 = torch_slice_prim_impl(t1566, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1567: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1568 = torch_slice_prim_impl(t1566, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1568: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1556, t1571] = nvFusion66(t1551, t1553, t1566, t1568)\n", + " # t1554 = prims.convert_element_type(t1553, dtypes.float32) # t1554: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1555 = prims.neg(t1554) # t1555: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1556 = prims.convert_element_type(t1555, dtypes.bfloat16) # t1556: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1569 = prims.convert_element_type(t1568, dtypes.float32) # t1569: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1570 = prims.neg(t1569) # t1570: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1571 = prims.convert_element_type(t1570, dtypes.bfloat16) # t1571: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1553, t1568\n", + " t1572 = torch.cat((t1571, t1567), -1) # t1572: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1572 = ltorch.cat((t1571, t1567), -1) # t1572: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1572 = prims.cat((t1571, t1567), -1) # t1572: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1571, t1567\n", + " t1557 = torch.cat((t1556, t1552), -1) # t1557: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1557 = ltorch.cat((t1556, t1552), -1) # t1557: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1557 = prims.cat((t1556, t1552), -1) # t1557: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1556, t1552\n", + " [t1565, t1580] = nvFusion67(t154, t1551, t1557, t1566, t157, t1572)\n", + " # t1559 = prims.convert_element_type(t1551, dtypes.float32) # t1559: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1574 = prims.convert_element_type(t1566, dtypes.float32) # t1574: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1575 = prims.mul(t1574, t154) # t1575: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1577 = prims.convert_element_type(t1572, dtypes.float32) # t1577: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1578 = prims.mul(t1577, t157) # t1578: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1579 = prims.add(t1575, t1578) # t1579: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1580 = prims.convert_element_type(t1579, dtypes.bfloat16) # t1580: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1560 = prims.mul(t1559, t154) # t1560: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1562 = prims.convert_element_type(t1557, dtypes.float32) # t1562: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1563 = prims.mul(t1562, t157) # t1563: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1564 = prims.add(t1560, t1563) # t1564: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1565 = prims.convert_element_type(t1564, dtypes.bfloat16) # t1565: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1551, t1557, t1566, t1572\n", + " t1584 = torch.cat((t1580, t1583), -1) # t1584: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1584 = ltorch.cat((t1580, t1583), -1) # t1584: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1584 = prims.cat((t1580, t1583), -1) # t1584: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1580, t1583\n", + " t1582 = torch.cat((t1565, t1581), -1) # t1582: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1582 = ltorch.cat((t1565, t1581), -1) # t1582: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1582 = prims.cat((t1565, t1581), -1) # t1582: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1565, t1581\n", + " (t1585, t1586, t1587, t1588, _, _, t1589, t1590, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1582, t1584, t1550, 0.0, True, scale=0.08838834764831843)\n", + " t1592 = torch.permute(t1585, (0, 2, 1, 3)) # t1592: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1592 = ltorch.permute(t1585, (0, 2, 1, 3)) # t1592: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1592 = prims.transpose(t1585, (0, 2, 1, 3)) # t1592: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1593 = torch.reshape(t1592, (1, 512, 4096)) # t1593: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1593 = ltorch.reshape(t1592, (1, 512, 4096)) # t1593: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1593 = prims.reshape(t1592, (1, 512, 4096)) # t1593: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1592\n", + " t1594 = torch.nn.functional.linear(t1593, t111, None) # t1594: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1594 = ltorch.linear(t1593, t111, None) # t1594: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1594 = prims.linear(t1593, t111, None) # t1594: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1598, t1605, t1613] = nvFusion68(t1526, t1594, t1609)\n", + " # t1596 = prims.convert_element_type(t1526, dtypes.float32) # t1596: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1595 = prims.convert_element_type(t1594, dtypes.float32) # t1595: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1597 = prims.add(t1595, t1596) # t1597: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1598 = prims.convert_element_type(t1597, dtypes.bfloat16) # t1598: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1600 = prims.mul(t1597, t1597) # t1600: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1601 = prims.sum(t1600, (2,)) # t1601: \"cuda:0 f32[1, 512]\"\n", + " # t1602 = prims.broadcast_in_dim(t1601, [1, 512, 1], [0, 1]) # t1602: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1603 = prims.div(t1602, 4096.0) # t1603: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1604 = prims.add(t1603, 1e-05) # t1604: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1605 = prims.rsqrt(t1604) # t1605: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1606 = prims.broadcast_in_dim(t1605, (1, 512, 4096), (0, 1, 2)) # t1606: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1607 = prims.mul(t1597, t1606) # t1607: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1611 = prims.convert_element_type(t1609, dtypes.float32) # t1611: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1612 = prims.mul(t1607, t1611) # t1612: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1613 = prims.convert_element_type(t1612, dtypes.bfloat16) # t1613: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1614 = torch.nn.functional.linear(t1613, t32, None) # t1614: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1614 = ltorch.linear(t1613, t32, None) # t1614: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1614 = prims.linear(t1613, t32, None) # t1614: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1615 = torch.nn.functional.linear(t1613, t48, None) # t1615: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1615 = ltorch.linear(t1613, t48, None) # t1615: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1615 = prims.linear(t1613, t48, None) # t1615: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1629] = nvFusion69(t1614, t1615)\n", + " # t1616 = prims.convert_element_type(t1614, dtypes.float32) # t1616: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1617 = prims.neg(t1616) # t1617: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1618 = prims.exp(t1617) # t1618: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1619 = prims.add(1.0, t1618) # t1619: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1620 = prims.reciprocal(t1619) # t1620: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1624 = prims.mul(t1616, t1620) # t1624: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1627 = prims.convert_element_type(t1615, dtypes.float32) # t1627: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1628 = prims.mul(t1624, t1627) # t1628: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1629 = prims.convert_element_type(t1628, dtypes.bfloat16) # t1629: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1630 = torch.nn.functional.linear(t1629, t112, None) # t1630: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1630 = ltorch.linear(t1629, t112, None) # t1630: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1630 = prims.linear(t1629, t112, None) # t1630: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1634, t1641, t1649] = nvFusion70(t1598, t1630, t1645)\n", + " # t1632 = prims.convert_element_type(t1598, dtypes.float32) # t1632: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1631 = prims.convert_element_type(t1630, dtypes.float32) # t1631: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1633 = prims.add(t1631, t1632) # t1633: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1634 = prims.convert_element_type(t1633, dtypes.bfloat16) # t1634: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1636 = prims.mul(t1633, t1633) # t1636: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1637 = prims.sum(t1636, (2,)) # t1637: \"cuda:0 f32[1, 512]\"\n", + " # t1638 = prims.broadcast_in_dim(t1637, [1, 512, 1], [0, 1]) # t1638: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1639 = prims.div(t1638, 4096.0) # t1639: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1640 = prims.add(t1639, 1e-05) # t1640: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1641 = prims.rsqrt(t1640) # t1641: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1642 = prims.broadcast_in_dim(t1641, (1, 512, 4096), (0, 1, 2)) # t1642: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1643 = prims.mul(t1633, t1642) # t1643: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1647 = prims.convert_element_type(t1645, dtypes.float32) # t1647: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1648 = prims.mul(t1643, t1647) # t1648: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1649 = prims.convert_element_type(t1648, dtypes.bfloat16) # t1649: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1650 = torch.nn.functional.linear(t1649, t17, None) # t1650: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1650 = ltorch.linear(t1649, t17, None) # t1650: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1650 = prims.linear(t1649, t17, None) # t1650: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1651 = torch.reshape(t1650, (1, 512, 32, 3, 128)) # t1651: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1651 = ltorch.reshape(t1650, (1, 512, 32, 3, 128)) # t1651: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1651 = prims.reshape(t1650, (1, 512, 32, 3, 128)) # t1651: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1650\n", + " t1652 = torch.permute(t1651, (0, 2, 3, 1, 4)) # t1652: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1652 = ltorch.permute(t1651, (0, 2, 3, 1, 4)) # t1652: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1652 = prims.transpose(t1651, (0, 2, 3, 1, 4)) # t1652: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1651\n", + " (t1653, t1654, t1655) = torch.split(t1652, (1, 1, 1), 2)\n", + " # (t1653, t1654, t1655) = ltorch.split(t1652, (1, 1, 1), 2)\n", + " # t1653 = prims.slice_prim(t1652, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1653: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1654 = prims.slice_prim(t1652, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1654: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1655 = prims.slice_prim(t1652, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1655: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1652\n", + " t1656 = torch.reshape(t1653, (1, 32, 512, 128)) # t1656: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1656 = ltorch.reshape(t1653, (1, 32, 512, 128)) # t1656: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1656 = prims.reshape(t1653, (1, 32, 512, 128)) # t1656: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1653\n", + " t1657 = torch.reshape(t1654, (1, 32, 512, 128)) # t1657: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1657 = ltorch.reshape(t1654, (1, 32, 512, 128)) # t1657: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1657 = prims.reshape(t1654, (1, 32, 512, 128)) # t1657: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1654\n", + " t1658 = torch.reshape(t1655, (1, 32, 512, 128)) # t1658: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1658 = ltorch.reshape(t1655, (1, 32, 512, 128)) # t1658: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1658 = prims.reshape(t1655, (1, 32, 512, 128)) # t1658: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1655\n", + " t1689 = torch_slice_prim_impl(t1656, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1689: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " t1691 = torch_slice_prim_impl(t1657, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1691: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " t1659 = torch_slice_prim_impl(t1656, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1659: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1656\n", + " t1674 = torch_slice_prim_impl(t1657, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1674: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1657\n", + " t1660 = torch_slice_prim_impl(t1659, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1660: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1661 = torch_slice_prim_impl(t1659, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1661: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1675 = torch_slice_prim_impl(t1674, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1675: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1676 = torch_slice_prim_impl(t1674, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1676: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1664, t1679] = nvFusion71(t1659, t1661, t1674, t1676)\n", + " # t1662 = prims.convert_element_type(t1661, dtypes.float32) # t1662: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1663 = prims.neg(t1662) # t1663: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1664 = prims.convert_element_type(t1663, dtypes.bfloat16) # t1664: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1677 = prims.convert_element_type(t1676, dtypes.float32) # t1677: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1678 = prims.neg(t1677) # t1678: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1679 = prims.convert_element_type(t1678, dtypes.bfloat16) # t1679: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1661, t1676\n", + " t1680 = torch.cat((t1679, t1675), -1) # t1680: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1680 = ltorch.cat((t1679, t1675), -1) # t1680: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1680 = prims.cat((t1679, t1675), -1) # t1680: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1679, t1675\n", + " t1665 = torch.cat((t1664, t1660), -1) # t1665: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1665 = ltorch.cat((t1664, t1660), -1) # t1665: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1665 = prims.cat((t1664, t1660), -1) # t1665: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1664, t1660\n", + " [t1673, t1688] = nvFusion72(t154, t157, t1659, t1665, t1674, t1680)\n", + " # t1667 = prims.convert_element_type(t1659, dtypes.float32) # t1667: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1682 = prims.convert_element_type(t1674, dtypes.float32) # t1682: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1683 = prims.mul(t1682, t154) # t1683: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1685 = prims.convert_element_type(t1680, dtypes.float32) # t1685: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1686 = prims.mul(t1685, t157) # t1686: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1687 = prims.add(t1683, t1686) # t1687: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1688 = prims.convert_element_type(t1687, dtypes.bfloat16) # t1688: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1668 = prims.mul(t1667, t154) # t1668: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1670 = prims.convert_element_type(t1665, dtypes.float32) # t1670: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1671 = prims.mul(t1670, t157) # t1671: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1672 = prims.add(t1668, t1671) # t1672: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1673 = prims.convert_element_type(t1672, dtypes.bfloat16) # t1673: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1659, t1665, t1674, t1680\n", + " t1692 = torch.cat((t1688, t1691), -1) # t1692: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1692 = ltorch.cat((t1688, t1691), -1) # t1692: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1692 = prims.cat((t1688, t1691), -1) # t1692: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1688, t1691\n", + " t1690 = torch.cat((t1673, t1689), -1) # t1690: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1690 = ltorch.cat((t1673, t1689), -1) # t1690: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1690 = prims.cat((t1673, t1689), -1) # t1690: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1673, t1689\n", + " (t1693, t1694, t1695, t1696, _, _, t1697, t1698, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1690, t1692, t1658, 0.0, True, scale=0.08838834764831843)\n", + " t1700 = torch.permute(t1693, (0, 2, 1, 3)) # t1700: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1700 = ltorch.permute(t1693, (0, 2, 1, 3)) # t1700: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1700 = prims.transpose(t1693, (0, 2, 1, 3)) # t1700: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1701 = torch.reshape(t1700, (1, 512, 4096)) # t1701: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1701 = ltorch.reshape(t1700, (1, 512, 4096)) # t1701: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1701 = prims.reshape(t1700, (1, 512, 4096)) # t1701: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1700\n", + " t1702 = torch.nn.functional.linear(t1701, t113, None) # t1702: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1702 = ltorch.linear(t1701, t113, None) # t1702: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1702 = prims.linear(t1701, t113, None) # t1702: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1706, t1713, t1721] = nvFusion73(t1634, t1702, t1717)\n", + " # t1704 = prims.convert_element_type(t1634, dtypes.float32) # t1704: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1703 = prims.convert_element_type(t1702, dtypes.float32) # t1703: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1705 = prims.add(t1703, t1704) # t1705: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1706 = prims.convert_element_type(t1705, dtypes.bfloat16) # t1706: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1708 = prims.mul(t1705, t1705) # t1708: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1709 = prims.sum(t1708, (2,)) # t1709: \"cuda:0 f32[1, 512]\"\n", + " # t1710 = prims.broadcast_in_dim(t1709, [1, 512, 1], [0, 1]) # t1710: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1711 = prims.div(t1710, 4096.0) # t1711: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1712 = prims.add(t1711, 1e-05) # t1712: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1713 = prims.rsqrt(t1712) # t1713: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1714 = prims.broadcast_in_dim(t1713, (1, 512, 4096), (0, 1, 2)) # t1714: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1715 = prims.mul(t1705, t1714) # t1715: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1719 = prims.convert_element_type(t1717, dtypes.float32) # t1719: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1720 = prims.mul(t1715, t1719) # t1720: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1721 = prims.convert_element_type(t1720, dtypes.bfloat16) # t1721: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1722 = torch.nn.functional.linear(t1721, t33, None) # t1722: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1722 = ltorch.linear(t1721, t33, None) # t1722: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1722 = prims.linear(t1721, t33, None) # t1722: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1723 = torch.nn.functional.linear(t1721, t49, None) # t1723: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1723 = ltorch.linear(t1721, t49, None) # t1723: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1723 = prims.linear(t1721, t49, None) # t1723: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1737] = nvFusion74(t1722, t1723)\n", + " # t1724 = prims.convert_element_type(t1722, dtypes.float32) # t1724: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1725 = prims.neg(t1724) # t1725: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1726 = prims.exp(t1725) # t1726: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1727 = prims.add(1.0, t1726) # t1727: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1728 = prims.reciprocal(t1727) # t1728: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1732 = prims.mul(t1724, t1728) # t1732: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1735 = prims.convert_element_type(t1723, dtypes.float32) # t1735: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1736 = prims.mul(t1732, t1735) # t1736: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1737 = prims.convert_element_type(t1736, dtypes.bfloat16) # t1737: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1738 = torch.nn.functional.linear(t1737, t114, None) # t1738: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1738 = ltorch.linear(t1737, t114, None) # t1738: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1738 = prims.linear(t1737, t114, None) # t1738: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1742, t1749, t1757] = nvFusion75(t1706, t1738, t1753)\n", + " # t1740 = prims.convert_element_type(t1706, dtypes.float32) # t1740: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1739 = prims.convert_element_type(t1738, dtypes.float32) # t1739: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1741 = prims.add(t1739, t1740) # t1741: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1742 = prims.convert_element_type(t1741, dtypes.bfloat16) # t1742: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1744 = prims.mul(t1741, t1741) # t1744: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1745 = prims.sum(t1744, (2,)) # t1745: \"cuda:0 f32[1, 512]\"\n", + " # t1746 = prims.broadcast_in_dim(t1745, [1, 512, 1], [0, 1]) # t1746: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1747 = prims.div(t1746, 4096.0) # t1747: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1748 = prims.add(t1747, 1e-05) # t1748: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1749 = prims.rsqrt(t1748) # t1749: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1750 = prims.broadcast_in_dim(t1749, (1, 512, 4096), (0, 1, 2)) # t1750: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1751 = prims.mul(t1741, t1750) # t1751: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1755 = prims.convert_element_type(t1753, dtypes.float32) # t1755: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1756 = prims.mul(t1751, t1755) # t1756: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1757 = prims.convert_element_type(t1756, dtypes.bfloat16) # t1757: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1758 = torch.nn.functional.linear(t1757, t18, None) # t1758: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1758 = ltorch.linear(t1757, t18, None) # t1758: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1758 = prims.linear(t1757, t18, None) # t1758: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1759 = torch.reshape(t1758, (1, 512, 32, 3, 128)) # t1759: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1759 = ltorch.reshape(t1758, (1, 512, 32, 3, 128)) # t1759: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1759 = prims.reshape(t1758, (1, 512, 32, 3, 128)) # t1759: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1758\n", + " t1760 = torch.permute(t1759, (0, 2, 3, 1, 4)) # t1760: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1760 = ltorch.permute(t1759, (0, 2, 3, 1, 4)) # t1760: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1760 = prims.transpose(t1759, (0, 2, 3, 1, 4)) # t1760: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1759\n", + " (t1761, t1762, t1763) = torch.split(t1760, (1, 1, 1), 2)\n", + " # (t1761, t1762, t1763) = ltorch.split(t1760, (1, 1, 1), 2)\n", + " # t1761 = prims.slice_prim(t1760, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1761: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1762 = prims.slice_prim(t1760, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1762: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1763 = prims.slice_prim(t1760, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1763: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1760\n", + " t1764 = torch.reshape(t1761, (1, 32, 512, 128)) # t1764: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1764 = ltorch.reshape(t1761, (1, 32, 512, 128)) # t1764: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1764 = prims.reshape(t1761, (1, 32, 512, 128)) # t1764: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1761\n", + " t1765 = torch.reshape(t1762, (1, 32, 512, 128)) # t1765: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1765 = ltorch.reshape(t1762, (1, 32, 512, 128)) # t1765: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1765 = prims.reshape(t1762, (1, 32, 512, 128)) # t1765: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1762\n", + " t1766 = torch.reshape(t1763, (1, 32, 512, 128)) # t1766: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1766 = ltorch.reshape(t1763, (1, 32, 512, 128)) # t1766: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1766 = prims.reshape(t1763, (1, 32, 512, 128)) # t1766: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1763\n", + " t1767 = torch_slice_prim_impl(t1764, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1767: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1782 = torch_slice_prim_impl(t1765, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1782: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1797 = torch_slice_prim_impl(t1764, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1797: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1764\n", + " t1799 = torch_slice_prim_impl(t1765, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1799: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1765\n", + " t1768 = torch_slice_prim_impl(t1767, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1768: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1769 = torch_slice_prim_impl(t1767, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1769: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1783 = torch_slice_prim_impl(t1782, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1783: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1784 = torch_slice_prim_impl(t1782, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1784: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1772, t1787] = nvFusion76(t1767, t1769, t1782, t1784)\n", + " # t1770 = prims.convert_element_type(t1769, dtypes.float32) # t1770: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1771 = prims.neg(t1770) # t1771: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1772 = prims.convert_element_type(t1771, dtypes.bfloat16) # t1772: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1785 = prims.convert_element_type(t1784, dtypes.float32) # t1785: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1786 = prims.neg(t1785) # t1786: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1787 = prims.convert_element_type(t1786, dtypes.bfloat16) # t1787: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1769, t1784\n", + " t1788 = torch.cat((t1787, t1783), -1) # t1788: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1788 = ltorch.cat((t1787, t1783), -1) # t1788: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1788 = prims.cat((t1787, t1783), -1) # t1788: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1787, t1783\n", + " t1773 = torch.cat((t1772, t1768), -1) # t1773: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1773 = ltorch.cat((t1772, t1768), -1) # t1773: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1773 = prims.cat((t1772, t1768), -1) # t1773: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1772, t1768\n", + " [t1781, t1796] = nvFusion77(t154, t157, t1767, t1773, t1782, t1788)\n", + " # t1775 = prims.convert_element_type(t1767, dtypes.float32) # t1775: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1790 = prims.convert_element_type(t1782, dtypes.float32) # t1790: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1791 = prims.mul(t1790, t154) # t1791: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1793 = prims.convert_element_type(t1788, dtypes.float32) # t1793: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1794 = prims.mul(t1793, t157) # t1794: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1795 = prims.add(t1791, t1794) # t1795: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1796 = prims.convert_element_type(t1795, dtypes.bfloat16) # t1796: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1776 = prims.mul(t1775, t154) # t1776: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1778 = prims.convert_element_type(t1773, dtypes.float32) # t1778: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1779 = prims.mul(t1778, t157) # t1779: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1780 = prims.add(t1776, t1779) # t1780: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1781 = prims.convert_element_type(t1780, dtypes.bfloat16) # t1781: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1767, t1773, t1782, t1788\n", + " t1800 = torch.cat((t1796, t1799), -1) # t1800: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1800 = ltorch.cat((t1796, t1799), -1) # t1800: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1800 = prims.cat((t1796, t1799), -1) # t1800: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1796, t1799\n", + " t1798 = torch.cat((t1781, t1797), -1) # t1798: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1798 = ltorch.cat((t1781, t1797), -1) # t1798: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1798 = prims.cat((t1781, t1797), -1) # t1798: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1781, t1797\n", + " (t1801, t1802, t1803, t1804, _, _, t1805, t1806, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1798, t1800, t1766, 0.0, True, scale=0.08838834764831843)\n", + " t1808 = torch.permute(t1801, (0, 2, 1, 3)) # t1808: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1808 = ltorch.permute(t1801, (0, 2, 1, 3)) # t1808: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1808 = prims.transpose(t1801, (0, 2, 1, 3)) # t1808: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1809 = torch.reshape(t1808, (1, 512, 4096)) # t1809: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1809 = ltorch.reshape(t1808, (1, 512, 4096)) # t1809: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1809 = prims.reshape(t1808, (1, 512, 4096)) # t1809: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1808\n", + " t1810 = torch.nn.functional.linear(t1809, t115, None) # t1810: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1810 = ltorch.linear(t1809, t115, None) # t1810: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1810 = prims.linear(t1809, t115, None) # t1810: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1814, t1821, t1829] = nvFusion78(t1742, t1810, t1825)\n", + " # t1812 = prims.convert_element_type(t1742, dtypes.float32) # t1812: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1811 = prims.convert_element_type(t1810, dtypes.float32) # t1811: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1813 = prims.add(t1811, t1812) # t1813: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1814 = prims.convert_element_type(t1813, dtypes.bfloat16) # t1814: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1816 = prims.mul(t1813, t1813) # t1816: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1817 = prims.sum(t1816, (2,)) # t1817: \"cuda:0 f32[1, 512]\"\n", + " # t1818 = prims.broadcast_in_dim(t1817, [1, 512, 1], [0, 1]) # t1818: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1819 = prims.div(t1818, 4096.0) # t1819: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1820 = prims.add(t1819, 1e-05) # t1820: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1821 = prims.rsqrt(t1820) # t1821: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1822 = prims.broadcast_in_dim(t1821, (1, 512, 4096), (0, 1, 2)) # t1822: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1823 = prims.mul(t1813, t1822) # t1823: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1827 = prims.convert_element_type(t1825, dtypes.float32) # t1827: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1828 = prims.mul(t1823, t1827) # t1828: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1829 = prims.convert_element_type(t1828, dtypes.bfloat16) # t1829: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1831 = torch.nn.functional.linear(t1829, t50, None) # t1831: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1831 = ltorch.linear(t1829, t50, None) # t1831: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1831 = prims.linear(t1829, t50, None) # t1831: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1830 = torch.nn.functional.linear(t1829, t34, None) # t1830: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1830 = ltorch.linear(t1829, t34, None) # t1830: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1830 = prims.linear(t1829, t34, None) # t1830: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1845] = nvFusion79(t1830, t1831)\n", + " # t1832 = prims.convert_element_type(t1830, dtypes.float32) # t1832: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1833 = prims.neg(t1832) # t1833: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1834 = prims.exp(t1833) # t1834: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1835 = prims.add(1.0, t1834) # t1835: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1836 = prims.reciprocal(t1835) # t1836: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1840 = prims.mul(t1832, t1836) # t1840: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1843 = prims.convert_element_type(t1831, dtypes.float32) # t1843: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1844 = prims.mul(t1840, t1843) # t1844: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1845 = prims.convert_element_type(t1844, dtypes.bfloat16) # t1845: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1846 = torch.nn.functional.linear(t1845, t116, None) # t1846: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1846 = ltorch.linear(t1845, t116, None) # t1846: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1846 = prims.linear(t1845, t116, None) # t1846: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1857, t1865] = nvFusion80(t1814, t1846, t1861)\n", + " # t1848 = prims.convert_element_type(t1814, dtypes.float32) # t1848: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1847 = prims.convert_element_type(t1846, dtypes.float32) # t1847: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1849 = prims.add(t1847, t1848) # t1849: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1852 = prims.mul(t1849, t1849) # t1852: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1853 = prims.sum(t1852, (2,)) # t1853: \"cuda:0 f32[1, 512]\"\n", + " # t1854 = prims.broadcast_in_dim(t1853, [1, 512, 1], [0, 1]) # t1854: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1855 = prims.div(t1854, 4096.0) # t1855: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1856 = prims.add(t1855, 1e-05) # t1856: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1857 = prims.rsqrt(t1856) # t1857: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1858 = prims.broadcast_in_dim(t1857, (1, 512, 4096), (0, 1, 2)) # t1858: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1859 = prims.mul(t1849, t1858) # t1859: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1863 = prims.convert_element_type(t1861, dtypes.float32) # t1863: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1864 = prims.mul(t1859, t1863) # t1864: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1865 = prims.convert_element_type(t1864, dtypes.bfloat16) # t1865: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1866 = torch.nn.functional.linear(t1865, t51, None) # t1866: \"cuda:0 bf16[1, 512, 32000]\"\n", + " # t1866 = ltorch.linear(t1865, t51, None) # t1866: \"cuda:0 bf16[1, 512, 32000]\"\n", + " # t1866 = prims.linear(t1865, t51, None) # t1866: \"cuda:0 bf16[1, 512, 32000]\"\n", + " return {'output': t1866, 'flat_args': [t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19, t20, t21, t22, t23, t24, t25, t26, t27, t28, t29, t30, t31, t32, t33, t34, t35, t36, t37, t38, t39, t40, t41, t42, t43, t44, t45, t46, t47, t48, t49, t50, t51, t52, t53, t54, t55, t56, t57, t58, t59, t60, t61, t62, t63, t64, t65, t66, t67, t68, t69, t70, t71, t72, t73, t74, t75, t76, t77, t78, t79, t80, t81, t82, t83, t84, t85, t86, t87, t88, t89, t90, t91, t92, t93, t94, t95, t96, t97, t98, t99, t100, t101, t102, t103, t104, t105, t106, t107, t108, t109, t110, t111, t112, t113, t114, t115, t116, t117], 'flat_output': (t1866,)}, ((t0, t10, t100, t1001, t101, t1010, t102, t103, t104, t1042, t1044, t1045, t1046, t1047, t1048, t1049, t105, t1050, t1053, t1054, t1058, t106, t1065, t1069, t107, t1073, t1074, t1075, t108, t1089, t109, t1090, t1094, t11, t110, t1101, t1105, t1109, t111, t1118, t112, t113, t114, t115, t1150, t1152, t1153, t1154, t1155, t1156, t1157, t1158, t116, t1161, t1162, t1166, t1173, t1177, t1181, t1182, t1183, t1197, t1198, t12, t1202, t1209, t1213, t1217, t122, t1226, t1258, t1260, t1261, t1262, t1263, t1264, t1265, t1266, t1269, t1270, t1274, t1281, t1285, t1289, t129, t1290, t1291, t13, t1305, t1306, t1310, t1317, t1321, t1325, t133, t1334, t1366, t1368, t1369, t137, t1370, t1371, t1372, t1373, t1374, t1377, t1378, t1382, t1389, t1393, t1397, t1398, t1399, t14, t1413, t1414, t1418, t1425, t1429, t1433, t1442, t146, t1474, t1476, t1477, t1478, t1479, t1480, t1481, t1482, t1485, t1486, t1490, t1497, t15, t1501, t1505, t1506, t1507, t1521, t1522, t1526, t1533, t1537, t154, t1541, t1550, t157, t1582, t1584, t1585, t1586, t1587, t1588, t1589, t1590, t1593, t1594, t1598, t16, t1605, t1609, t1613, t1614, t1615, t1629, t1630, t1634, t1641, t1645, t1649, t1658, t1690, t1692, t1693, t1694, t1695, t1696, t1697, t1698, t17, t1701, t1702, t1706, t1713, t1717, t1721, t1722, t1723, t1737, t1738, t1742, t1749, t1753, t1757, t1766, t178, t1798, t18, t180, t1800, t1801, t1802, t1803, t1804, t1805, t1806, t1809, t181, t1810, t1814, t182, t1821, t1825, t1829, t183, t1830, t1831, t184, t1845, t1846, t185, t1857, t186, t1861, t1865, t189, t19, t190, t194, t20, t201, t205, t209, t21, t210, t211, t22, t225, t226, t23, t230, t237, t24, t241, t245, t25, t254, t26, t27, t28, t286, t288, t289, t29, t290, t291, t292, t293, t294, t297, t298, t3, t30, t302, t309, t31, t313, t317, t318, t319, t32, t33, t333, t334, t338, t34, t345, t349, t35, t353, t36, t362, t37, t38, t39, t394, t396, t397, t398, t399, t4, t40, t400, t401, t402, t405, t406, t41, t410, t417, t42, t421, t425, t426, t427, t43, t44, t441, t442, t446, t45, t453, t457, t46, t461, t47, t470, t48, t49, t5, t50, t502, t504, t505, t506, t507, t508, t509, t51, t510, t513, t514, t518, t525, t529, t533, t534, t535, t549, t550, t554, t561, t565, t569, t578, t6, t610, t612, t613, t614, t615, t616, t617, t618, t621, t622, t626, t633, t637, t641, t642, t643, t657, t658, t662, t669, t673, t677, t686, t7, t718, t720, t721, t722, t723, t724, t725, t726, t729, t730, t734, t741, t745, t749, t750, t751, t765, t766, t770, t777, t781, t785, t794, t8, t826, t828, t829, t830, t831, t832, t833, t834, t837, t838, t842, t849, t85, t853, t857, t858, t859, t86, t87, t873, t874, t878, t88, t885, t889, t89, t893, t9, t90, t902, t91, t92, t93, t934, t936, t937, t938, t939, t94, t940, t941, t942, t945, t946, t95, t950, t957, t96, t961, t965, t966, t967, t97, t98, t981, t982, t986, t99, t993, t997), (False, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 0.0, 4096.0, 4096.0, 0.08838834764831843, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 32000, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2))" ] }, - "execution_count": 60, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "x = torch.randn(2, 2048, 4096, device=\"cuda\")\n", - "(tm(x) - m(x)).abs().max()\n" + "print(actual.grad_fn)\n", + "thunder.last_traces(thunder_model)[-1]" + ] + }, + { + "cell_type": "markdown", + "id": "558f2553-37f7-4b58-b7cd-a744155613a8", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Well, that is quite a bit to look through.\n", + "But here is a key thing: The function now returns a bunch of things. This is because Thunder applies the same treatment to the backward and to this end saves information from the forward. You can see a hint of this because the output has a `ThunderFunctionBackward` on as its `grad_fn`. (You can see the backward trace with \n", + "`thunder.last_backward_traces(thunder_model)[-1]`)." ] }, { "cell_type": "code", - "execution_count": 61, - "id": "a6f4b77c", + "execution_count": 10, + "id": "59643398-d6e2-4c32-81bd-145a1198b1f3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[# Constructed by Augmented forward pass\n", - " import thunder\n", - " import thunder.core.prims as prims\n", - " import torch\n", - " from thunder.executors.torchex import no_autocast\n", - " \n", - " @torch.no_grad()\n", - " @no_autocast()\n", - " def augmented_forward_fn(input, t_fc_1_weight, t_fc_2_weight, t_proj_weight):\n", - " # input: \"cuda:0 f32[2, 2048, 4096]\" \n", - " # t_fc_1_weight: \"cuda:0 f32[11008, 4096]\" \n", - " # t_fc_2_weight: \"cuda:0 f32[11008, 4096]\" \n", - " # t_proj_weight: \"cuda:0 f32[4096, 11008]\" \n", - " t0 = prims.linear(input, t_fc_1_weight, None) # t0: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t1 = prims.linear(input, t_fc_2_weight, None) # t1: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t2 = prims.neg(t0) # t2: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t3 = prims.exp(t2) # t3: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t4 = prims.add(1.0, t3) # t4: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t5 = prims.reciprocal(t4) # t5: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t6 = prims.mul(t0, t5) # t6: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t7 = prims.mul(t6, t1) # t7: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t8 = prims.linear(t7, t_proj_weight, None) # t8: \"cuda:0 f32[2, 2048, 4096]\"\n", - " return {'output': t8, 'flat_args': [input, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((input, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()),\n", - " # Constructed by Transform for execution (took 2 milliseconds)\n", - " import torch\n", - " import torch.nn.functional\n", - " from thunder.executors.torchex import no_autocast\n", - " \n", - " @torch.no_grad()\n", - " @no_autocast()\n", - " def augmented_forward_fn(input, t_fc_1_weight, t_fc_2_weight, t_proj_weight):\n", - " # input: \"cuda:0 f32[2, 2048, 4096]\" \n", - " # t_fc_1_weight: \"cuda:0 f32[11008, 4096]\" \n", - " # t_fc_2_weight: \"cuda:0 f32[11008, 4096]\" \n", - " # t_proj_weight: \"cuda:0 f32[4096, 11008]\" \n", - " t0 = torch.nn.functional.linear(input, t_fc_1_weight, None) # t0: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t0 = ltorch.linear(input, t_fc_1_weight, None) # t0: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t0 = prims.linear(input, t_fc_1_weight, None) # t0: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t1 = torch.nn.functional.linear(input, t_fc_2_weight, None) # t1: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t1 = ltorch.linear(input, t_fc_2_weight, None) # t1: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t1 = prims.linear(input, t_fc_2_weight, None) # t1: \"cuda:0 f32[2, 2048, 11008]\"\n", - " [t3, t5, t6, t7] = nvFusion0(t0, t1)\n", - " # t2 = prims.neg(t0) # t2: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t3 = prims.exp(t2) # t3: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t4 = prims.add(1.0, t3) # t4: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t5 = prims.reciprocal(t4) # t5: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t6 = prims.mul(t0, t5) # t6: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t7 = prims.mul(t6, t1) # t7: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: \"cuda:0 f32[2, 2048, 4096]\"\n", - " # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: \"cuda:0 f32[2, 2048, 4096]\"\n", - " # t8 = prims.linear(t7, t_proj_weight, None) # t8: \"cuda:0 f32[2, 2048, 4096]\"\n", - " return {'output': t8, 'flat_args': [input, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((input, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()),\n", - " # Constructed by Update Call Context (took 0 milliseconds)\n", - " import torch\n", - " import torch.nn.functional\n", - " from thunder.executors.torchex import no_autocast\n", - " \n", - " @torch.no_grad()\n", - " @no_autocast()\n", - " def augmented_forward_fn(input, t_fc_1_weight, t_fc_2_weight, t_proj_weight):\n", - " # input: \"cuda:0 f32[2, 2048, 4096]\" \n", - " # t_fc_1_weight: \"cuda:0 f32[11008, 4096]\" \n", - " # t_fc_2_weight: \"cuda:0 f32[11008, 4096]\" \n", - " # t_proj_weight: \"cuda:0 f32[4096, 11008]\" \n", - " t0 = torch.nn.functional.linear(input, t_fc_1_weight, None) # t0: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t0 = ltorch.linear(input, t_fc_1_weight, None) # t0: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t0 = prims.linear(input, t_fc_1_weight, None) # t0: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t1 = torch.nn.functional.linear(input, t_fc_2_weight, None) # t1: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t1 = ltorch.linear(input, t_fc_2_weight, None) # t1: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t1 = prims.linear(input, t_fc_2_weight, None) # t1: \"cuda:0 f32[2, 2048, 11008]\"\n", - " [t7] = nvFusion0(t0, t1)\n", - " # t2 = prims.neg(t0) # t2: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t3 = prims.exp(t2) # t3: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t4 = prims.add(1.0, t3) # t4: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t5 = prims.reciprocal(t4) # t5: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t6 = prims.mul(t0, t5) # t6: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t7 = prims.mul(t6, t1) # t7: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: \"cuda:0 f32[2, 2048, 4096]\"\n", - " # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: \"cuda:0 f32[2, 2048, 4096]\"\n", - " # t8 = prims.linear(t7, t_proj_weight, None) # t8: \"cuda:0 f32[2, 2048, 4096]\"\n", - " return {'output': t8, 'flat_args': [input, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((input, t0, t1, t7, t_proj_weight), ()),\n", - " # Constructed by Delete Last Used (took 0 milliseconds)\n", - " import torch\n", - " import torch.nn.functional\n", - " from thunder.executors.torchex import no_autocast\n", - " \n", - " @torch.no_grad()\n", - " @no_autocast()\n", - " def augmented_forward_fn(input, t_fc_1_weight, t_fc_2_weight, t_proj_weight):\n", - " # input: \"cuda:0 f32[2, 2048, 4096]\" \n", - " # t_fc_1_weight: \"cuda:0 f32[11008, 4096]\" \n", - " # t_fc_2_weight: \"cuda:0 f32[11008, 4096]\" \n", - " # t_proj_weight: \"cuda:0 f32[4096, 11008]\" \n", - " t0 = torch.nn.functional.linear(input, t_fc_1_weight, None) # t0: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t0 = ltorch.linear(input, t_fc_1_weight, None) # t0: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t0 = prims.linear(input, t_fc_1_weight, None) # t0: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t1 = torch.nn.functional.linear(input, t_fc_2_weight, None) # t1: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t1 = ltorch.linear(input, t_fc_2_weight, None) # t1: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t1 = prims.linear(input, t_fc_2_weight, None) # t1: \"cuda:0 f32[2, 2048, 11008]\"\n", - " [t7] = nvFusion0(t0, t1)\n", - " # t2 = prims.neg(t0) # t2: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t3 = prims.exp(t2) # t3: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t4 = prims.add(1.0, t3) # t4: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t5 = prims.reciprocal(t4) # t5: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t6 = prims.mul(t0, t5) # t6: \"cuda:0 f32[2, 2048, 11008]\"\n", - " # t7 = prims.mul(t6, t1) # t7: \"cuda:0 f32[2, 2048, 11008]\"\n", - " t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: \"cuda:0 f32[2, 2048, 4096]\"\n", - " # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: \"cuda:0 f32[2, 2048, 4096]\"\n", - " # t8 = prims.linear(t7, t_proj_weight, None) # t8: \"cuda:0 f32[2, 2048, 4096]\"\n", - " return {'output': t8, 'flat_args': [input, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((input, t0, t1, t7, t_proj_weight), ())]" + "tensor([[[ 0.4160, -0.4668, 1.1016, ..., 0.5430, 1.2656, 0.2891],\n", + " [ 0.3320, -0.0557, 1.7891, ..., 1.0703, 1.0078, 1.2266],\n", + " [ 0.6836, -0.2871, 0.9531, ..., 0.0806, 0.7070, 0.8477],\n", + " ...,\n", + " [ 0.7695, -0.1260, 0.7266, ..., 0.1118, -0.0238, -1.2656],\n", + " [-0.7773, -0.5547, -0.3047, ..., -0.1807, 0.1895, 0.6875],\n", + " [ 0.8867, 0.4766, 0.3984, ..., 0.0815, -0.0879, 0.3477]]],\n", + " device='cuda:0', grad_fn=)" ] }, - "execution_count": 61, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "thunder.last_traces(tm)[0]" + "actual" + ] + }, + { + "cell_type": "markdown", + "id": "17341d86-d4c9-46bd-ac5e-3a05da1ff72c", + "metadata": {}, + "source": [ + "Let us clean up a bit." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "6ba7f715", + "metadata": {}, + "outputs": [], + "source": [ + "del actual, expected\n", + "import gc\n", + "gc.collect();" + ] + }, + { + "cell_type": "markdown", + "id": "0261eb11", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "But is it faster? Yes!" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "bccec79b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "240 ms ± 105 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", + "208 ms ± 147 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "%timeit r = m(inp); torch.autograd.grad(r.sum(), m.parameters()); torch.cuda.synchronize()\n", + "%timeit r = thunder_model(inp); torch.autograd.grad(r.sum(), m.parameters()); torch.cuda.synchronize()" + ] + }, + { + "cell_type": "markdown", + "id": "1d31e7f8", + "metadata": {}, + "source": [ + "So far, so good! Thunder should work with LitGPT today and we busy are adding the support required to run other models as well!\n" ] }, { "cell_type": "code", - "execution_count": null, - "id": "ce4217b7", + "execution_count": 13, + "id": "ecad9125-bbf2-42c8-b11c-23eed4a6cd8f", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "del m, thunder_model\n", + "import gc\n", + "gc.collect()\n", + "torch.cuda.empty_cache()\n" + ] + }, + { + "cell_type": "markdown", + "id": "49e3273c-99be-4370-9e59-121c00481b4e", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "## Distributed with Thunder\n", + "\n", + "Those Large Language Models are called Large for a reason, and memory in a single GPU is invariably small. So we need multiple.\n", + "\n", + "Happily Thunder sports an FSDP interface to use multiple cards in our box.\n", + "\n", + "You still need to setup the process group, but as far as the model is concerned,\n", + "\n", + "```python\n", + "model = thunder.jit(thunder.distributed.fsdp(model))\n", + "```\n", + "\n", + "is all you need. Because it is tricky to run multiprocessing from Notebooks, we write a small example into a file and run it though `torch-run`.\n", + "\n", + "Check out our LitGPT Thunder examples for complete distributed training and finetuning!" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "18dd3379", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Overwriting zero_to_thunder_fsdp_simple_example.py\n" + ] + } + ], + "source": [ + "%%writefile zero_to_thunder_fsdp_simple_example.py\n", + "from thunder.tests.lit_gpt_model import GPT, Config\n", + "import os\n", + "import torch, torch.distributed\n", + "import thunder, thunder.distributed\n", + "\n", + "# Create Model\n", + "# NOTE: We create the model on CPU.\n", + "device='cpu'\n", + "torch.set_default_dtype(torch.bfloat16)\n", + "cfg = Config.from_name('Llama-2-7b-hf')\n", + "cfg.n_layer = 8 # fewer layers\n", + "model = GPT(cfg)\n", + "\n", + "# Setup for distributed\n", + "torch.distributed.init_process_group(backend='nccl')\n", + "rank = int(os.environ[\"LOCAL_RANK\"])\n", + "\n", + "device = f\"cuda:{rank}\"\n", + "x = torch.randint(1, model.config.vocab_size, (1, 1024), device=device)\n", + "\n", + "# thunder.distributed.fsdp takes care of moving the parameter\n", + "# shard to the correct GPU for the current process.\n", + "model = thunder.jit(thunder.distributed.fsdp(model)) # <---------------------------------------\n", + "print(f\"rank {rank} computing\")\n", + "# Run the forward pass.\n", + "for i in range(10):\n", + " res = model(x)\n", + " res.sum().backward()\n" + ] + }, + { + "cell_type": "markdown", + "id": "97e8edbf-424d-49a7-8ed6-12cb5e5d65fc", + "metadata": {}, + "source": [ + "Now we can launch it. Note that you need two GPUs for this to run correctly." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "2bad9b64", + "metadata": { + "scrolled": true, + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] \n", + "W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] *****************************************\n", + "W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \n", + "W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] *****************************************\n", + "rank 1 computing\n", + "rank 0 computing\n" + ] + } + ], + "source": [ + "!torchrun --nproc_per_node=2 zero_to_thunder_fsdp_simple_example.py" + ] + }, + { + "cell_type": "markdown", + "id": "9c65e75d", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "So there. FSDP with just wrapping the model in `fsdp`.\n" + ] + }, + { + "cell_type": "markdown", + "id": "4a6d7a20", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "## Extending Thunder\n", + "\n", + "But we promised that thunder is extensible. Let's find out what's up with that.\n", + "\n", + "Specifically, we will incorporate the fast rope embedding kernel from the great [Unsloth project](https://github.com/unslothai/unsloth/) into our model (note that NVFuser also creates a fused kernel for this).\n", + "\n", + "In Thunder, extensions (as well as most builtin optimizations which use the exact same mechanism) work with _executors_ handling operations. Let us define one." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "f7639065", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "my_ex" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "my_ex = thunder.extend.OperatorExecutor('my_ex', version='0.0.1')\n", + "thunder.extend.register_executor(my_ex)" + ] + }, + { + "cell_type": "markdown", + "id": "2fe3b40b-c6e9-417c-ab7a-32606cee871a", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "For our base implementation, we take the code from [LitGPT's implementation](https://github.com/Lightning-AI/litgpt/blob/be6139e1fd4b240d253efd58124457496d23d173/litgpt/model.py#L355-L361)\n", + "\n", + "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function.\n", + "Because we will demonstrate Thunder's ability to divert functions in the model, we make a version here that will not be diverted." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "3e74436b-d8eb-472b-9d6d-b6412378fde7", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "import lit_gpt\n", + "def apply_rope_copy(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n", + " head_size = x.size(-1)\n", + " x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)\n", + " x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)\n", + " rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)\n", + " roped = (x * cos) + (rotated * sin)\n", + " return roped.to(dtype=x.dtype)" + ] + }, + { + "cell_type": "markdown", + "id": "a63595ab", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "### Registering operators\n", + "\n", + "Say we have a function `apply_rope` applying the RoPE transformation in PyTorch.\n", + "\n", + "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function and tell it to use the new symbol instead of the original function `lit_gpt.model.apply_rope`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "247074b3", + "metadata": {}, + "outputs": [], + "source": [ + "import torch, thunder\n", + "from thunder.tests.lit_gpt_model import GPT\n", + "from thunder import TensorProxy\n", + "\n", + "def apply_rope_impl(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n", + " return lit_gpt.model.apply_rope(x, cos, sin)\n", + "\n", + "def apply_rope_meta(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:\n", + " return TensorProxy(like=x)\n", + "\n", + "apply_rope = my_ex.register_operator('apply_rope', like=apply_rope_meta, fn=apply_rope_impl,\n", + " replaces=lit_gpt.model.apply_rope)" + ] + }, + { + "cell_type": "markdown", + "id": "d6b7d056", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Testing our new operator " + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "0ebd5dd1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "deviation: 0.0\n" + ] + }, + { + "data": { + "text/plain": [ + "# Constructed by Delete Last Used (took 0 milliseconds)\n", + "import torch\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def computation(x, t_1_cos, t_1_sin):\n", + " # x: \"cuda:0 bf16[2, 128, 4096, 16]\" \n", + " # t_1_cos: \"cuda:0 f32[4096, 16]\" \n", + " # t_1_sin: \"cuda:0 f32[4096, 16]\" \n", + " t2 = apply_rope(x, t_1_cos, t_1_sin) # t2: \"cuda:0 bf16[2, 128, 4096, 16]\"\n", + " del x, t_1_cos, t_1_sin\n", + " return t2" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with torch.device('cuda'): m = GPT.from_name('llama2-like'); Q = torch.randn(2, 128, 4096, 16)\n", + "\n", + "def test_apply_rope(x, m):\n", + " return lit_gpt.model.apply_rope(x, m.cos, m.sin)\n", + "\n", + "thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors()) \n", + "\n", + "expected = test_apply_rope(Q, m); actual = thunder_apply_rope(Q, m); print(\"deviation:\", (expected - actual).abs().max().item())\n", + "\n", + "thunder.last_traces(thunder_apply_rope)[-1]" + ] + }, + { + "cell_type": "markdown", + "id": "8c620a38", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Optimized kernels\n", + "\n", + "But why did we do this? Well, we can now layer a faster implementation on top.\n", + "For this we take the [unsloth fast rope embedding](https://github.com/unslothai/unsloth/blob/42076f6580e71522ed1c122043edfba595be64e4/unsloth/kernels/rope_embedding.py) kernels. We take the bits that were in the forward and backward of the `autograd.Function` into our implementation functions. Note that we include the transpositions in our setup in order to have compatibility to the LitGPT implementation. This change in memory layout of the operands can have a large effect on the runtime though, so our timings are likely not representative of the ones the Unsloth project gets in their use of the same triton kernels." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "6e6d0b1e-ba14-43e5-b0d9-27c0e3b46879", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License.\n", + "\n", + "import triton\n", + "import triton.language as tl\n", + "import torch\n", + "\n", + "MAX_FUSED_SIZE = 65536\n", + "next_power_of_2 = triton.next_power_of_2\n", + "\n", + "def calculate_settings(n):\n", + " BLOCK_SIZE = next_power_of_2(n)\n", + " if BLOCK_SIZE > MAX_FUSED_SIZE:\n", + " raise RuntimeError(f\"Cannot launch Triton kernel since n = {n} exceeds \"\\\n", + " f\"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.\")\n", + " num_warps = 4\n", + " if BLOCK_SIZE >= 32768: num_warps = 32\n", + " elif BLOCK_SIZE >= 8192: num_warps = 16\n", + " elif BLOCK_SIZE >= 2048: num_warps = 8\n", + " return BLOCK_SIZE, num_warps\n", + "\n", + "@triton.heuristics({\"BACKWARD_PASS\": lambda args: args[\"BACKWARD_PASS\"],})\n", + "@triton.jit\n", + "def _rope_embedding(\n", + " Q, Q_row_stride,\n", + " cos, cos_row_stride,\n", + " sin, sin_row_stride,\n", + " seqlen, head_dim, group_size, n_heads,\n", + " BACKWARD_PASS: tl.constexpr,\n", + " BLOCK_SIZE : tl.constexpr,\n", + "):\n", + " \"\"\"\n", + " Calculates the RoPE Embedding quickly\n", + " RoPE is Q * cos + rotate_half(Q) * sin\n", + " See our blog post for more info\n", + " \"\"\"\n", + " row_position = tl.program_id(0)\n", + " group_head_position = tl.program_id(1)\n", + " col_offsets = tl.arange(0, BLOCK_SIZE)\n", + " half_head_dim = head_dim // 2\n", + " mask = col_offsets < half_head_dim\n", + "\n", + " sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \\\n", + " half_head_dim*0 + col_offsets, mask = mask, other = 0)\n", + " cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \\\n", + " half_head_dim*0 + col_offsets, mask = mask, other = 0)\n", + "\n", + " if BACKWARD_PASS:\n", + " # See our blog post for more info.\n", + " sin1 = -sin1\n", + " pass\n", + "\n", + " head_start = group_head_position * group_size\n", + " head_end = min((head_start + group_size), n_heads)\n", + "\n", + " for i in range(head_start, head_end):\n", + " offs_q1 = row_position * Q_row_stride + i * head_dim + col_offsets\n", + " offs_q2 = row_position * Q_row_stride + i * head_dim + col_offsets + half_head_dim\n", + "\n", + " # For Gemma - sometimes RoPE must be done in float32 and not bfloat16\n", + " Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)\n", + " Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)\n", + "\n", + " tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)\n", + " tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)\n", + " pass\n", + "pass\n", + "\n", + "\n", + "def fast_rope_embedding_forward(Q, cos, sin):\n", + " Q = Q.transpose(1, 2).clone()\n", + " cos, sin = cos.squeeze(), sin.squeeze()\n", + " batch, seq_len, n_heads, head_dim = Q.shape\n", + " Q = Q.reshape(batch*seq_len, n_heads*head_dim)\n", + " n_rows, n_cols = Q.shape\n", + " assert(seq_len <= cos.shape[0])\n", + "\n", + " # [TODO] Changing blocksize to head_dim//2 seems to have\n", + " # some concurrency / un-deterministic issues.\n", + " BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)\n", + " group_size = 4 # 4 or 8, too large group_size can hurt performance.\n", + " n_groups = triton.cdiv(n_heads, group_size)\n", + "\n", + " grid = (n_rows, n_groups, )\n", + " _rope_embedding[grid](\n", + " Q, Q.stride(0),\n", + " cos, cos.stride(0),\n", + " sin, sin.stride(0),\n", + " seq_len, head_dim, group_size, n_heads,\n", + " BACKWARD_PASS = False,\n", + " BLOCK_SIZE = BLOCK_SIZE,\n", + " num_warps = num_warps,\n", + " )\n", + " Q = Q.view(batch, seq_len, n_heads, head_dim).transpose(1, 2)\n", + " return Q, (BLOCK_SIZE, num_warps) \n", + "\n", + "def fast_rope_embedding_backward(BLOCK_SIZE, num_warps, cos, sin, dY):\n", + " dY = dY.transpose(1, 2)\n", + " batch, seq_len, n_heads, head_dim = dY.shape\n", + " dY = dY.reshape(batch*seq_len, n_heads*head_dim)\n", + " # Must be reshape not view\n", + " n_rows, n_cols = dY.shape\n", + "\n", + " group_size = 4 # 4 or 8, too large group_size can hurt performance.\n", + " n_groups = triton.cdiv(n_heads, group_size)\n", + "\n", + " grid = (n_rows, n_groups, )\n", + " _rope_embedding[grid](\n", + " dY, dY .stride(0),\n", + " cos, cos.stride(0),\n", + " sin, sin.stride(0),\n", + " seq_len, head_dim, group_size, n_heads,\n", + " BACKWARD_PASS = True,\n", + " BLOCK_SIZE = BLOCK_SIZE,\n", + " num_warps = num_warps,\n", + " )\n", + " dY = dY.view(batch, seq_len, n_heads, head_dim)\n", + " dY = dY.transpose(1, 2) \n", + " return dY\n" + ] + }, + { + "cell_type": "markdown", + "id": "ed1e9be3-d1c9-4c4b-bf14-a025a03687ac", + "metadata": {}, + "source": [ + "We also define the corresponding meta functions." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "d7e6612d-f1fc-497c-9d64-15ef99824086", + "metadata": {}, + "outputs": [], + "source": [ + "def fast_rope_embedding_forward_meta(Q, cos, sin):\n", + " batch, n_heads, seq_len, head_dim = Q.shape\n", + " n_rows, n_cols = batch*seq_len, n_heads*head_dim \n", + " assert(seq_len <= cos.shape[0])\n", + "\n", + " BLOCK_SIZE, num_warps = calculate_settings(head_dim//2)\n", + " return TensorProxy(like=Q), (BLOCK_SIZE, num_warps) \n", + "\n", + "def fast_rope_embedding_backward_meta(BLOCK_SIZE, num_warps, cos, sin, dY):\n", + " return TensorProxy(like=dY)" + ] + }, + { + "cell_type": "markdown", + "id": "b70eba5f", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Register optimized operators\n", + "\n", + "Just like the `apply_rope` before, we can register operators for the optimized forward and backward." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "f8f1e77e", + "metadata": {}, + "outputs": [], + "source": [ + "unsloth_apply_rope_forward = my_ex.register_operator('unsloth_apply_rope_forward', \n", + " meta=fast_rope_embedding_forward_meta, fn=fast_rope_embedding_forward)\n", + "unsloth_apply_rope_backward = my_ex.register_operator('unsloth_apply_rope_backward', \n", + " meta=fast_rope_embedding_backward_meta, fn=fast_rope_embedding_backward)" + ] + }, + { + "cell_type": "markdown", + "id": "2426263d", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Implementations for operators\n", + "\n", + "Do we need to divert `apply_rope` again? No!\n", + "We can register the specialized kernel as an _implementation_ of our base `apply_rope` operator. For this we need an _execution transform_ - which is a fancy word for a function that implements the original operator (`apply_ropw`) in terms of our new operator - so it has the call signature of the `apply_rope`. Because - like many fast implementations - the unsloth rope embedding does not implement the operator in full generality (well, actually they mainly want a 4d tensor input), we implement a checker function, too: It takes the arguments of the operator we want specialize and returns a bool whether our implementation handles the given inputs." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "6b5c8320", + "metadata": {}, + "outputs": [], + "source": [ + "def apply_rope_to_unsloth(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:\n", + " assert len(x.shape) == 4\n", + " res, *_ = unsloth_apply_rope_forward(x, cos, sin)\n", + " return res\n", + "\n", + "def apply_rope_to_unsloth_checker(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> bool:\n", + " if len(x.shape) != 4:\n", + " return False\n", + " return (x.device.devicetype == thunder.devices.DeviceType.CUDA and\n", + " cos.device.devicetype == thunder.devices.DeviceType.CUDA and\n", + " cos.device.devicetype == thunder.devices.DeviceType.CUDA)\n", + "\n", + "my_ex.register_implementation(apply_rope,\n", + " checker=apply_rope_to_unsloth_checker,\n", + " execution_transform=apply_rope_to_unsloth)\n" + ] + }, + { + "cell_type": "markdown", + "id": "eec7c95a", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "So let us give it a try! Works great..." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "965ba1d7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "deviation: 0.015625\n" + ] + }, + { + "data": { + "text/plain": [ + "# Constructed by Delete Last Used (took 0 milliseconds)\n", + "import torch\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def computation(x, t_1_cos, t_1_sin):\n", + " # x: \"cuda:0 bf16[2, 128, 4096, 16]\" \n", + " # t_1_cos: \"cuda:0 f32[4096, 16]\" \n", + " # t_1_sin: \"cuda:0 f32[4096, 16]\" \n", + " (t2, (_, _)) = unsloth_apply_rope_forward(x, t_1_cos, t_1_sin)\n", + " del x, t_1_cos, t_1_sin\n", + " return t2" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors()) \n", + "\n", + "expected = test_apply_rope(Q, m)\n", + "actual = thunder_apply_rope(Q, m)\n", + "print(\"deviation:\", (expected - actual).abs().max().item())\n", + "\n", + "thunder.last_traces(thunder_apply_rope)[-1]" + ] + }, + { + "cell_type": "markdown", + "id": "69a93d3d-3a88-4297-b330-23a7fff2c4b4", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "And this is also automatic when we instantiate a larger llama2-like model:" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "7fff2522", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "deviation: 5.960464477539062e-07\n" + ] + } + ], + "source": [ + "torch.set_default_dtype(torch.float32)\n", + "with torch.device('cuda'):\n", + " m = GPT(Config.from_name('llama2-like'))\n", + "\n", + "for p in m.parameters():\n", + " p.requires_grad_(False)\n", + "\n", + "thunder_model = thunder.jit(m, executors=(my_ex,) + thunder.get_default_executors())\n", + "\n", + "inp = torch.randint(1, m.config.vocab_size, (1, 128), device=\"cuda\")\n", + "actual = thunder_model(inp)\n", + "expected = m(inp)\n", + "\n", + "print(\"deviation:\", (actual - expected).abs().max().item())" + ] + }, + { + "cell_type": "markdown", + "id": "b538cb40", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "By peeking into the trace, we can see that it actually used the unsloth apply rope:" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "c260cb25", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[' (q_roped, (_, _)) = unsloth_apply_rope_forward(t55, cos, sin)',\n", + " ' (k_roped, (_, _)) = unsloth_apply_rope_forward(t57, cos, sin)',\n", + " ' (t165, (_, _)) = unsloth_apply_rope_forward(t164, cos, sin)',\n", + " ' (t167, (_, _)) = unsloth_apply_rope_forward(t166, cos, sin)']" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[s for s in str(thunder.last_traces(thunder_model)[-1]).split('\\n') if 'apply_rope' in s]" + ] + }, + { + "cell_type": "markdown", + "id": "0f6c0780", + "metadata": {}, + "source": [ + "### But what about the backward?\n", + "\n", + "Well, we have to connect forward and backward with a grad transformation. With our specialized ops, this is very simple, we compute the forward, call `get_grad` for the output, compute the backward, and put it on the input with `put_grads`. \n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "7670a872", + "metadata": {}, + "outputs": [], + "source": [ + "from thunder.core.transforms import get_grad, put_grads\n", + "\n", + "def unsloth_apply_rope_grad(x: TensorProxy, cos: TensorProxy, sin: TensorProxy):\n", + " res, (BLOCK_SIZE, num_warps) = unsloth_apply_rope_forward(x, cos, sin)\n", + " grad_res = get_grad(res)\n", + " grad_x = unsloth_apply_rope_backward(BLOCK_SIZE, num_warps, cos, sin, grad_res)\n", + " put_grads((x,), (grad_x,))\n", + " return res\n", + "\n", + "my_ex.register_implementation(apply_rope, checker=apply_rope_to_unsloth_checker,\n", + " execution_transform=apply_rope_to_unsloth,\n", + " grad_transform=unsloth_apply_rope_grad \n", + " )\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "219dfaa4-cdef-47de-b60c-7c7c1642cb84", + "metadata": {}, + "source": [ + "Note that the parts are not actually executed at the same time in the actual computation, but just during tracing.\n" + ] + }, + { + "cell_type": "markdown", + "id": "68226a4a-6ad8-43fb-b92f-c1e8eec6f13e", + "metadata": {}, + "source": [ + "And let us try our function using the optimized backward" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "ccc3ed63-ddc2-4b0e-bcd0-f77d66fefe9f", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "res deviation: 0.015625\n", + "grad deviation: 0.0078125\n" + ] + } + ], + "source": [ + "Q.requires_grad_()\n", + "\n", + "thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors())\n", + "\n", + "expected = test_apply_rope(Q, m)\n", + "go = torch.ones_like(expected)\n", + "gr_expected, = torch.autograd.grad(expected, Q, go)\n", + "actual = thunder_apply_rope(Q, m)\n", + "gr_actual, = torch.autograd.grad(actual, Q, go)\n", + "\n", + "print(\"res deviation:\", (expected - actual).abs().max().item())\n", + "print(\"grad deviation:\", (gr_expected - gr_actual).abs().max().item())" + ] + }, + { + "cell_type": "markdown", + "id": "63cb61ee-c791-49d1-ba5c-3fe4b5b9a9d5", + "metadata": {}, + "source": [ + "And with `last_backward_traces` we can check that our module is using the unsloth backward:" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "cd12ca02-6f06-4d88-b5b7-25c4c27dbc9a", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "# Constructed by Delete Last Used (took 0 milliseconds)\n", + "import torch\n", + "from thunder.executors.torchex import no_autocast\n", + "\n", + "@torch.no_grad()\n", + "@no_autocast()\n", + "def backward_fn(saved_for_backward, cotangents):\n", + " # saved_for_backward: \"Collection\" \n", + " # cotangents: \"Collection\" \n", + " C0, \\\n", + " _, \\\n", + " = saved_for_backward\n", + " clear_collection(saved_for_backward)\n", + " del saved_for_backward\n", + " t4, \\\n", + " = cotangents\n", + " clear_collection(cotangents)\n", + " del cotangents\n", + " t1, \\\n", + " t2, \\\n", + " = C0\n", + " clear_collection(C0)\n", + " del C0\n", + " t3 = unsloth_apply_rope_backward(8, 4, t1, t2, t4) # t3: \"cuda:0 bf16[2, 128, 4096, 16]\"\n", + " del t1, t2, t4\n", + " return (t3, None, None)" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "thunder.last_backward_traces(thunder_apply_rope)[-1]" + ] + }, + { + "cell_type": "markdown", + "id": "2776d183-0232-495e-aa75-3b90e799c841", + "metadata": {}, + "source": [ + "### Comparing and exploring optimizations\n", + "\n", + "It is also straightforward to compare potential optimizations.\n", + "\n", + "Note again, that our use of the unsloth kernel might not result in the same performance as the unsloth project sees due to differences in the hardware used, software environment, or memory layout of the operands." + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "a5e0ce05", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "eager\n", + "3.84 ms ± 3.46 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", + "thunder + unsloth\n", + "6.69 ms ± 3.45 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", + "thunder default (nvfuser)\n", + "1.4 ms ± 4.98 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" + ] + } + ], + "source": [ + "def test_apply_rope_copy(x, m):\n", + " return apply_rope_copy(x, m.cos, m.sin)\n", + "\n", + "test_apply_rope_myex = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors()) \n", + "test_apply_rope_nvfuser = thunder.jit(test_apply_rope_copy)\n", + "y = test_apply_rope(Q, m); gr, = torch.autograd.grad(y, Q, go)\n", + "y = test_apply_rope_myex(Q, m); gr, = torch.autograd.grad(y, Q, go)\n", + "y = test_apply_rope_nvfuser(Q, m); gr, = torch.autograd.grad(y, Q, go)\n", + "\n", + "print(\"eager\")\n", + "%timeit y = test_apply_rope(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()\n", + "print(\"thunder + unsloth\")\n", + "%timeit y = test_apply_rope_myex(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()\n", + "print(\"thunder default (nvfuser)\")\n", + "%timeit y = test_apply_rope_nvfuser(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()\n" + ] + }, + { + "cell_type": "markdown", + "id": "08b8454f-c725-470c-92a5-56b2206af0e8", + "metadata": {}, + "source": [ + "That's it!\n", + "\n", + "## Conclusion\n", + "\n", + "To wrap up, we hope you got a taste of\n", + "\n", + "- Getting things going with Thunder:\n", + "\n", + " - Applying Thunder through `thunder.jit` and\n", + " - using FSDP by just wrapping the model in `thunder.distributed.fsdp` before compilation.\n", + "\n", + "- See what's going on inspecting traces:\n", + "\n", + " - `thunder.last_traces` for the forward traces,\n", + " - `thunder.last_backward_traces` for the backward,\n", + " \n", + "- Extending Thunder:\n", + "\n", + " - registering operators with the `OperatorExecutor`,\n", + " - defining implementations with custom forward and backward to include optimized kernels.\n", + "\n", + "Keep in mind that Thunder is still experimental and only expected to work with the limited set of models we have tested it with. You will find bugs and missing pieces. Naturally, we would love for you to help us fix these! You can find us on the [Thunder section of the Lightning forums](https://lightning.ai/forums/c/thunder) or in the `#thunder` channel on the [PyTorch-Lightning slack](https://pytorch-lightning.slack.com/). \n", + "\n", + "Do check out our LitGPT studios and the other tutorial notebooks.\n" + ] } ], "metadata": { + "celltoolbar": "Slideshow", "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -285,7 +4199,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.10.10" } }, "nbformat": 4, diff --git a/requirements/docs.txt b/requirements/docs.txt index f12c617045..14615a08c4 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,7 +1,7 @@ sphinx ==5.3.0 myst-parser ==1.0.0 nbsphinx ==0.9.3 -ipython[all] ==8.22.1 +ipython[all] ==8.22.2 pandoc ==2.3 docutils >=0.16 sphinxcontrib-fulltoc ==1.2.0 diff --git a/requirements/notebooks.txt b/requirements/notebooks.txt index d2f9d92e56..47a14902ca 100644 --- a/requirements/notebooks.txt +++ b/requirements/notebooks.txt @@ -1 +1 @@ -ipython[all] ==8.22.1 +ipython[all] ==8.22.2 diff --git a/requirements/test.txt b/requirements/test.txt index c6d95336ef..a1402bd69b 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -8,7 +8,7 @@ pytest-timestamper ==0.0.9 graphviz ==0.20.1 fdm ==0.4.1 expecttest ==0.2.1 # for test_ddp.py -hypothesis ==6.98.15 # for test_ddp.py +hypothesis ==6.99.10 # for test_ddp.py numpy # for test_ops.py einops # for test_einops.py lit_gpt @ git+https://github.com/Lightning-AI/lit-gpt@f241d94df59d82b2017bfdcd3800ac8779eb45f5 diff --git a/setup.py b/setup.py index 0f20566fcd..b0ee14e897 100755 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ def _load_py_module(fname, pkg="thunder"): def _load_requirements(path_dir: str, file_name: str = "requirements.txt") -> list: reqs = parse_requirements(open(os.path.join(path_dir, file_name)).readlines()) - return list(map(str, reqs)) + return [r for r in list(map(str, reqs)) if "@" not in r] def _prepare_extras( @@ -43,13 +43,24 @@ def _prepare_extras( return extras +def _load_readme_description(path_dir: str, homepage: str, version: str) -> str: + """Load readme as decribtion.""" + path_readme = os.path.join(path_dir, "README.md") + with open(path_readme, encoding="utf-8") as fp: + text = fp.read() + # https://github.com/Lightning-AI/lightning-thunder/raw/master/docs/source/_static/images/lightning_module/pt_to_pl.png + github_source_url = os.path.join(homepage, "raw", version) + # replace relative repository path to absolute link to the release + # do not replace all "docs" as in the readme we replace some other sources with particular path to docs + text = text.replace("docs/source/_static/", f"{os.path.join(github_source_url, 'docs/source/_static/')}") + return text + + about = _load_py_module("__about__.py") # https://packaging.python.org/discussions/install-requires-vs-requirements / -# keep the meta-data here for simplicity in reading this file... it's not obvious -# what happens and to non-engineers they won't know to look in init ... -# the goal of the project is simplicity for researchers, don't want to add too much -# engineer specific practices +# keep the meta-data here for simplicity in reading this file. it's not obvious +# what happens and to non-engineers they won't know to look in init. setup( name="lightning-thunder", version=about.__version__, @@ -60,7 +71,9 @@ def _prepare_extras( download_url="https://github.com/Lightning-AI/lightning-thunder", license=about.__license__, packages=find_packages(exclude=["thunder/tests", "docs"]), - long_description=about.__long_doc__, + long_description=_load_readme_description( + path_dir=_PATH_ROOT, homepage=about.__homepage__, version=about.__version__ + ), long_description_content_type="text/markdown", include_package_data=True, zip_safe=False, diff --git a/thunder/__about__.py b/thunder/__about__.py index d9fb64ba1d..15e838ef4f 100644 --- a/thunder/__about__.py +++ b/thunder/__about__.py @@ -1,21 +1,17 @@ -__version__ = "0.0.0dev" +__version__ = "0.1.0" __author__ = "Lightning-AI et al" __author_email__ = "community@lightning.ai" __license__ = "Apache 2.0" __copyright__ = f"2024 {__author__}" __homepage__ = "https://github.com/Lightning-AI/lightning-thunder" __docs__ = "Lightning Thunder project." -# todo: consider loading Readme here... -__long_doc__ = """ -Lightning Thunder is a deep learning compiler for PyTorch. -""" + __all__ = [ "__author__", "__author_email__", "__copyright__", "__docs__", - "__long_doc__", "__homepage__", "__license__", "__version__", diff --git a/thunder/__init__.py b/thunder/__init__.py index df817a6e43..5d6f698a76 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -1,6 +1,6 @@ from functools import wraps from typing import Any -from collections import defaultdict +from collections import defaultdict, namedtuple from collections.abc import Callable from collections.abc import Sequence from contextlib import contextmanager @@ -15,6 +15,7 @@ from thunder.core.options import ( INTERPRETATION_OPTIONS, resolve_interpretation_option, + resolve_sharp_edges_option, CACHE_OPTIONS, SHARP_EDGES_OPTIONS, ) @@ -277,6 +278,21 @@ def _recursive_jit_call_warning() -> None: ) +CacheEntry = namedtuple( + "CacheEntry", + [ + "prologue_fn", + "prologue_traces", + "computation_fn", + "computation_traces", + "epilogue_fn", + "epilogue_traces", + "backward_fn", + "backward_traces", + ], +) + + # This function will replace compile() (below) before RC1 # TODO RC1 Consider adding a debug_log parameter to control debug printing # TODO RC1 Consider renaming compile_options to additional_compile_options @@ -328,6 +344,11 @@ def jit( if additional_transforms is None: additional_transforms = [] + # TODO: verify that tutorials don't have false positives and enable warning by default + # # Make sharp_edges == warn default if not supplied and if in the general jit + # if interpretation is INTERPRETATION_OPTIONS.TRANSLATE_PYTHON and sharp_edges is None: + # sharp_edges = SHARP_EDGES_OPTIONS.WARN + # TODO RC1 Refine the compile data option to remove unused options cd = CompileData( fn=fn, @@ -386,9 +407,17 @@ def get_computation_and_inputs(*args, **kwargs): # Checks cache cs.last_trace_cache_start = time.time_ns() if (cd.cache_option is CACHE_OPTIONS.CONSTANT_VALUES) or (cd.cache_option is CACHE_OPTIONS.SYMBOLIC_VALUES): - for pro, pro_traces, comp, comp_traces, epilogue, epilogue_traces, backward_fn, backward_traces in reversed( - cs.interpreter_cache - ): + for cache_entry in reversed(cs.interpreter_cache): + ( + pro, + pro_traces, + comp, + comp_traces, + epilogue, + epilogue_traces, + backward_fn, + backward_traces, + ) = cache_entry try: cs.last_prologue_execution_start = time.time_ns() if epilogue: @@ -415,10 +444,11 @@ def get_computation_and_inputs(*args, **kwargs): cs.last_computation_transformation_start = 0 cs.last_computation_transformation_stop = 0 - return inps, pro_to_epi, comp, epilogue, backward_fn + return cache_entry, inps, pro_to_epi if cd.cache_option is CACHE_OPTIONS.SAME_INPUT: if len(cs.interpreter_cache): + cache_entry = cs.interpreter_cache[0] ( pro, pro_traces, @@ -428,7 +458,7 @@ def get_computation_and_inputs(*args, **kwargs): epilogue_traces, backward_fn, backward_traces, - ) = cs.interpreter_cache[0] + ) = cache_entry cs.last_prologue_execution_start = time.time_ns() if epilogue: @@ -449,7 +479,7 @@ def get_computation_and_inputs(*args, **kwargs): cs.last_prologue_traces = pro_traces cs.last_prologue = pro - return inps, pro_to_epi, comp, epilogue, backward_fn + return cache_entry, inps, pro_to_epi cs.cache_misses += 1 cs.last_trace_cache_stop = time.time_ns() @@ -503,6 +533,9 @@ def get_computation_and_inputs(*args, **kwargs): cs.last_prologue_execution_stop = time.time_ns() computation_traces = [computation_trc] + cs.last_traces = computation_traces + backward_traces = [] + cs.last_backward_traces = backward_traces computation_trc = dce(computation_trc) computation_traces.append(computation_trc) @@ -524,10 +557,9 @@ def get_computation_and_inputs(*args, **kwargs): # thunder_backward may recursively call compile and wraps the result in a # torch.autograd.Function to support embedding of Thunder-compiled # functions in torch's Autograd - computation_trc, backward_trc = split_forward_backward( - computation_trc.python_callable(), cd, cs, *inps - ) - computation_traces.append(computation_trc) + computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps) + # Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces + # by split_forward_backward cs.last_computation_transformation_start = time.time_ns() @@ -547,23 +579,24 @@ def get_computation_and_inputs(*args, **kwargs): if backward_trc is not None: backward_fn = backward_trc.python_callable() - backward_traces = [backward_trc] else: backward_fn = None - backward_traces = [] # TODO RC1 Update the cache + cache_entry = CacheEntry( + pro, protraces, comp, extraces, epilogue, epilogue_traces, backward_fn, backward_traces + ) if cd.cache_option is not CACHE_OPTIONS.NO_CACHING: - cs.interpreter_cache.append( - (pro, protraces, comp, extraces, epilogue, epilogue_traces, backward_fn, backward_traces) - ) + cs.interpreter_cache.append(cache_entry) cs.last_computation_transformation_stop = time.time_ns() - cs.last_traces = [computation_trc] + extraces + cs.last_traces += extraces cs.last_prologue_traces = [prologue_trc] + protraces cs.last_prologue = pro - return inps, pro_to_epi, comp, epilogue, backward_fn + return cache_entry, inps, pro_to_epi + + cd.get_computation_and_inputs = get_computation_and_inputs @wraps(fn) def fn_(*args, **kwargs) -> Any: @@ -575,18 +608,18 @@ def fn_(*args, **kwargs) -> Any: cs.last_trace_host_start = time.time_ns() cs.calls += 1 - inps, pro_to_epi, comp, epilogue, backward_fn = get_computation_and_inputs(*args, **kwargs) + cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs) cs.last_trace_host_execution_start = time.time_ns() - result = comp(*inps) + result = cache_entry.computation_fn(*inps) - if backward_fn: + if cache_entry.backward_fn: # Run the compiled forward function data_for_autograd, (saved_tensors, saved_other) = result # Connect produced tensors with PyTorch's autograd graph ThunderFunction.apply( - backward_fn, + cache_entry.backward_fn, saved_tensors, saved_other, data_for_autograd["flat_output"], @@ -594,21 +627,17 @@ def fn_(*args, **kwargs) -> Any: ) result = data_for_autograd["output"] - if epilogue: + if cache_entry.epilogue_fn: result, comp_to_epi = result - epilogue(*pro_to_epi, *comp_to_epi) + cache_entry.epilogue_fn(*pro_to_epi, *comp_to_epi) cs.last_trace_host_execution_stop = time.time_ns() cs.last_computation_execution_stop = cs.last_trace_host_execution_stop - cs.last_executed = comp + cs.last_executed = cache_entry.computation_fn cs.last_trace_cache_stop = time.time_ns() cs.last_trace_host_stop = time.time_ns() - # Updates statistics - cs.last_executed = comp - cs.last_trace_host_stop = time.time_ns() - return result if isinstance(fn, pytorch.nn.Module): @@ -676,25 +705,31 @@ def compile_stats(fn) -> CompileStats | None: return getattr(fn, "_lc_cs", None) -# TODO We should remove compiledata.last_traces in favor of forward_last_traces and backward_last_traces -# TODO: should we return fw and bw from separate functions. The return type (list or tuple of lists) is not so nice -def last_traces(fn) -> list[TraceCtx] | tuple[list[TraceCtx], list[TraceCtx]]: +def last_traces(fn) -> list[TraceCtx]: """Obtains the list of computation traces that have been produced for the last run of the function. This is a list of traces mirroring the progression of transformations being applied to the trace (at index 0) that has been acquired from interpreting the user program. - If the function has forward and backward, a tuple of them is returned. + If the function has forward and backward, the forward is returned. """ cs = compile_stats(fn) if cs is None: raise TypeError(f"{fn} doesn't seem to be a thunder compiled function.") - if cs.forward_last_traces is not None and cs.backward_last_traces is not None: - return cs.forward_last_traces, cs.backward_last_traces if cs.last_traces is None: raise TypeError(f"{fn} doesn't seem to have been called yet.") return cs.last_traces +def last_backward_traces(fn) -> TraceCtx: + """Obtains the list of backward traces that have been produced for the last run of the function and the selected prologue.""" + cs = compile_stats(fn) + if cs is None: + raise TypeError(f"{fn} doesn't seem to be a thunder compiled function.") + if cs.last_backward_traces is None: + raise TypeError(f"{fn} doesn't seem to have been called yet.") + return cs.last_backward_traces + + def last_prologue_traces(fn) -> TraceCtx: """Obtains the list of prologue traces that have been produced for the last run of the function and the selected prologue.""" cs = compile_stats(fn) diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py index 3bad00892d..9120584989 100644 --- a/thunder/benchmarks/benchmark_litgpt.py +++ b/thunder/benchmarks/benchmark_litgpt.py @@ -1,7 +1,5 @@ import os -import copy import time -import pprint import torch import functools @@ -11,26 +9,18 @@ import thunder from thunder.tests.lit_gpt_model import Config, GPT, Block -try: - from lightning.fabric.utilities.throughput import measure_flops +from lightning.fabric.utilities.throughput import measure_flops +from lightning.fabric.utilities import Throughput - # from lightning.fabric.utilities import Throughput - LIGHTNING_AVAILABLE = True -except: - LIGHTNING_AVAILABLE = False -world_size, local_rank, global_rank = None, None, None -if "WORLD_SIZE" in os.environ and "LOCAL_RANK" in os.environ: +world_size = int(os.environ.get("WORLD_SIZE", 1)) +local_rank = int(os.environ.get("LOCAL_RANK", 0)) +global_rank = int(os.environ.get("RANK", 0)) +if world_size > 1: torch_dist.init_process_group(backend="nccl") - world_size = int(os.environ["WORLD_SIZE"]) - local_rank = int(os.environ["LOCAL_RANK"]) - global_rank = int(os.environ["RANK"]) pg = torch_dist.distributed_c10d._get_default_group() - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - use_ddp = True -else: - device = torch.device("cuda", 0) +device = torch.device("cuda", local_rank) +torch.cuda.set_device(device) def configure_optimizers(model, weight_decay, learning_rate, betas, device_type): @@ -38,7 +28,9 @@ def configure_optimizers(model, weight_decay, learning_rate, betas, device_type) fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters use_fused = fused_available and device_type == "cuda" - optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=betas, fused=use_fused) + optimizer = torch.optim.AdamW( + model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=betas, fused=use_fused + ) return optimizer @@ -112,16 +104,21 @@ def __init__( if n_layers is not None: self.config.n_layer = n_layers - # Initialize the model and the optimizer + # Initialize the model + t0 = time.perf_counter() + print(f"Loading model with {self.config.__dict__}") self.model = self.init_model() - self.optimizer = configure_optimizers( - self.model, weight_decay, learning_rate, (beta1, beta2), device_type="cuda" - ) + print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") # Setup the distributed algorithm choices if self.distributed_mode != "none": self.model = self.setup_distributed() + # Initialize the optimizer after the model is sharded if using FSDP + self.optimizer = configure_optimizers( + self.model, weight_decay, learning_rate, (beta1, beta2), device_type="cuda" + ) + # Compile the model if self.compile not in ["eager", None]: self.model = self.setup_compile() @@ -140,13 +137,10 @@ def __init__( } def init_model(self): - print(f"Loading model with {self.config.__dict__}") - t0 = time.perf_counter() - with self.device: + init_device = torch.device("meta") if self.distributed_mode == "fsdp" else self.device + with init_device: model = GPT(self.config) - model.to(dtype=torch.bfloat16) - print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") - + model.to(dtype=torch.bfloat16) return model def setup_distributed(self): @@ -244,7 +238,7 @@ def pad_collate(batch): y_padded = torch.nn.utils.rnn.pad_sequence(y, batch_first=True, padding_value=-1) return x_padded, y_padded - train_data = DummyDataset(self.model.max_seq_length, self.dynamic) + train_data = DummyDataset(self.config.block_size, self.dynamic) train_dataloader = DataLoader( train_data, batch_size=self.micro_batch_size, num_workers=2, collate_fn=pad_collate ) @@ -252,24 +246,30 @@ def pad_collate(batch): return train_dataloader def calculate_model_flops(self): - input_ids, targets = next(self.train_data_iter) - input_ids = input_ids.to(device=self.device) - targets = targets.to(device=self.device) + meta = torch.device("meta") + device = self.device + self.device = meta + + # calculate flops on a meta-device model because we only care about the shapes and + # because the flops calculator installs hooks on the model + meta_model = self.init_model() - model_fwd = lambda: self.model(input_ids) + x = torch.randint(0, 1, (self.micro_batch_size, meta_model.config.block_size), device=meta) + model_fwd = lambda: meta_model(x) model_loss = lambda y: torch.nn.functional.cross_entropy( - y.reshape(-1, y.size(-1)), targets.reshape(-1), ignore_index=-1 + y.reshape(-1, y.size(-1)), x.reshape(-1), ignore_index=-1 ) - if LIGHTNING_AVAILABLE: - self.perf_metrics["model_flops"] = measure_flops(self.model, model_fwd, model_loss) / 1e12 + self.perf_metrics["model_flops"] = measure_flops(meta_model, model_fwd, model_loss) + + self.device = device def train(self): t0 = None - # if global_rank in [0, None]: - # #Calculate the model FLOPs - # self.calculate_model_flops() - # Setup Perf Collection - # self.throughput = Throughput(window_size=10, world_size=world_size) + if global_rank in [0, None]: + # Calculate the model FLOPs + self.calculate_model_flops() + # Setup throughput Collection + self.throughput = Throughput(window_size=self.max_iters - self.warmup_iter, world_size=world_size) if "transformerengine" in self.compile: import transformer_engine.pytorch as te @@ -327,45 +327,30 @@ def train(self): print( f"iter {i}: loss {loss_item:.4f}, iter time: {(t1 - iter_t0) * 1000:.2f}ms, t: {input_ids.size(1)}" ) - - # if global_rank in [0, None] and i >=warmup_iter: - # self.throughput.update( - # time=(t1-t0), - # flops=self.model_flops, - # batches=i, - # samples=(i * self.micro_batch_size * self.gradient_accumulation_steps), - # lengths=(i * self.micro_batch_size * self.gradient_accumulation_steps * self.model.max_seq_length), - # ) - - # metrics = self.throughput.compute() - # if i % 10 == 0: - # print(metrics) + if i >= self.warmup_iter: + self.throughput.update( + time=(t1 - t0), + flops=self.perf_metrics["model_flops"], + batches=i, + samples=(i * self.micro_batch_size * self.gradient_accumulation_steps), + lengths=(i * self.micro_batch_size * self.gradient_accumulation_steps * self.config.block_size), + ) if global_rank in [0, None]: # print(f"Total time: {(t1 - t0):.2f}s") - # print(f"Average time per iter: {((t1 - t0)*1000)/(max_iters-warmup_iter):.2f}ms") self.perf_metrics["average_iter_time"] = ((t1 - t0) * 1000) / (self.max_iters - self.warmup_iter) def add_perf_metrics(self): - # tokens_per_sec = total number of benchmarked iterations x global BS x block_size / total elapsed time (s) - # = global BS x block_size / (total elapsed time (s)/total number of benchmarked iterations) - # = global BS x block_size / average iter time (s) - self.perf_metrics["tokens_per_sec"] = ( - self.global_batch_size * self.model.max_seq_length * 1000 / self.perf_metrics["average_iter_time"] - ) # tokens/s - if self.perf_metrics["model_flops"] is not None: - self.perf_metrics["model_flop_per_sec"] = ( - self.perf_metrics["model_flops"] * 1000 / self.perf_metrics["average_iter_time"] - ) - if world_size is not None: - self.perf_metrics["model_flop_per_sec"] *= world_size + metrics = self.throughput.compute() + self.perf_metrics["tokens_per_sec"] = metrics.get("items_per_sec", metrics["device/items_per_sec"]) + self.perf_metrics["model_flop_per_sec"] = metrics.get("flops_per_sec", metrics["device/flops_per_sec"]) self.perf_metrics["memory_used_GB"] = torch.cuda.max_memory_allocated() / 1e9 def add_model_info_to_metrics(self): if global_rank in [0, None]: self.perf_metrics["model_name"] = self.model_name self.perf_metrics["Num GPUS"] = world_size - self.perf_metrics["Seq Len"] = self.model.max_seq_length + self.perf_metrics["Seq Len"] = self.config.block_size self.perf_metrics["Micro BS"] = self.micro_batch_size self.perf_metrics["Global BS"] = self.global_batch_size self.perf_metrics["GA"] = self.gradient_accumulation_steps @@ -417,7 +402,7 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None benchmark.add_perf_metrics() print( - f"Model name: {benchmark.model_name}\nSeq Length: {benchmark.model.max_seq_length}\nMicro BS: {benchmark.micro_batch_size}\nGlobal BS: {benchmark.global_batch_size}" + f"Model name: {benchmark.model_name}\nSeq Length: {benchmark.config.block_size}\nMicro BS: {benchmark.micro_batch_size}\nGlobal BS: {benchmark.global_batch_size}" ) print( f"Number of Layers: {benchmark.config.n_layer}\nNumber of parameters: {sum(p.numel() for p in benchmark.model.parameters() if p.requires_grad) / 1e9:.02f}B" @@ -430,12 +415,9 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None print(f"Compiler: {benchmark.compile}") print(f"Average iter time: {benchmark.perf_metrics['average_iter_time']:.2f} ms") print(f"Memory used: {benchmark.perf_metrics['memory_used_GB']:.02f} GB") - print(f"Throughput (Tokens/s): {benchmark.perf_metrics['tokens_per_sec']:.02f} tokens/s") - print( - f"Normalized Throughput (Tokens/s/GPU): {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f} tokens/s/gpu" - ) - if benchmark.perf_metrics["model_flop_per_sec"] is not None: - print(f"Model TFLOP/s: {benchmark.perf_metrics['model_flop_per_sec']:.02f} TFLOP/s") + print(f"Tokens/s: {benchmark.perf_metrics['tokens_per_sec']:.02f}") + print(f"Tokens/s/GPU: {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f}") + print(f"TFLOP/s: {benchmark.perf_metrics['model_flop_per_sec'] / 1e12:.02f}") except Exception as error: # Helps catch OutOfMemory Errors and post processing of errors diff --git a/thunder/benchmarks/distributed.py b/thunder/benchmarks/distributed.py index 6c9e5e889b..c46a699ca5 100644 --- a/thunder/benchmarks/distributed.py +++ b/thunder/benchmarks/distributed.py @@ -211,7 +211,7 @@ def parse_args() -> argparse.Namespace: # TODO Port these benchmarks to pytest (and targets.py) -# See https://github.com/Lightning-AI/lightning-thunder/issues/1404 +# See issue "Create distributed pytest benchmarks" if __name__ == "__main__": args = parse_args() diff --git a/thunder/benchmarks/targets.py b/thunder/benchmarks/targets.py index bd54735beb..9f02b23d35 100644 --- a/thunder/benchmarks/targets.py +++ b/thunder/benchmarks/targets.py @@ -289,7 +289,9 @@ def wrapper(*args, **kwargs): return wrapper -# To compare with PyTorch and torchcompile +# To compare with PyTorch and raw torch.compile (i.e. not through thunder). The +# latter can help us isolate whether it's something we need to fix ourself or +# report upstream. torch_fwd_bwd = partial(thunder_fwd_bwd, compile_fn=torch_executor) torchcompile_fwd_bwd = partial(thunder_fwd_bwd, compile_fn=torch_compile_executor) @@ -432,7 +434,8 @@ def test_nanogpt_gelu_grad(benchmark, executor: Callable): # TODO Improve cross entropy's fwd+bwd perf when using the PyTorch executor -# See https://github.com/Lightning-AI/lightning-thunder/issues/1319 +# See "torch.cross_entropy implementation has incorrect dtype metadata + bwd +# is very slow" @pytest.mark.parametrize( "executor,", fwd_executors, @@ -454,7 +457,8 @@ def test_nanogpt_cross_entropy_fwd(benchmark, executor: None | Callable): # TODO Improve cross entropy's fwd+bwd perf when using the PyTorch executor -# See https://github.com/Lightning-AI/lightning-thunder/issues/1319 +# See "torch.cross_entropy implementation has incorrect dtype metadata + bwd +# is very slow" @pytest.mark.parametrize( "executor,", (grad_executors + apex_grad_executors), @@ -476,7 +480,8 @@ def test_nanogpt_cross_entropy_grad(benchmark, executor: None | Callable): # TODO Improve cross entropy's fwd+bwd perf when using the PyTorch executor -# See https://github.com/Lightning-AI/lightning-thunder/issues/1319 +# See "torch.cross_entropy implementation has incorrect dtype metadata + bwd +# is very slow" @pytest.mark.parametrize( "executor,", (fwd_executors + cudnn_layernorm_fwd_executors), diff --git a/thunder/clang/__init__.py b/thunder/clang/__init__.py index 6032034cc8..3fd192802f 100644 --- a/thunder/clang/__init__.py +++ b/thunder/clang/__init__.py @@ -19,9 +19,8 @@ import thunder.core.prims as prims from thunder.core.proxies import TensorProxy, pyval, pytype, proxy, AnyProxy, Proxy import thunder.core.devices as devices -from thunder.core.script.noinline import noinline -# This file defines the operations in lightning.compile's "core" language. +# This file defines the operations in thunder.jit's "core" language. # # These operators are intended to be used when defining user-facing languages, like the torch or NumPy # languages. @@ -34,7 +33,6 @@ _clang_fn_set: set = set() -# TODO RC1 Remove noinline # Decorator that sets the core language context and registers the function class clangop: def __init__(self, *, method_name: None | str = None): @@ -42,7 +40,6 @@ def __init__(self, *, method_name: None | str = None): def __call__(self, fn: Callable) -> Callable: _fn = langctx(Languages.CLANG)(fn) - _fn = noinline(_fn) _clang_fn_set.add(_fn) if self.method_name is not None: @@ -1005,7 +1002,7 @@ def stride_order(a: TensorLike, order: None | Sequence[int] = None) -> TensorLik .. note:: - No other lightning.compile operations specify how their outputs are represented in memory, and lightning.compile + No other thunder.jit operations specify how their outputs are represented in memory, and thunder.jit does not model strides. This operation is an explicit directive to construct a dense, non-overlapping and strided tensor, but operations on that tensor do not have to preserve those properties. """ diff --git a/thunder/common.py b/thunder/common.py index fbdfaa0db7..67afb496a0 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -62,9 +62,7 @@ def __init__(self): self.last_interpreted_history = None # torch.autograd.Function specific data - self.primal_trace = None - self.forward_last_traces = None - self.backward_last_traces = None + self.last_backward_traces = None # Timing stats self.last_trace_host_start: int = -1 @@ -133,40 +131,6 @@ def last_computation_execution_time(self, /) -> int: return self._time_template(start, stop, "computation execution") -import thunder.core.script.frontend as script_frontend -import thunder.core.script.instrumentation as script_instrumentation -import thunder.core.script.passes as passes -import thunder.core.script.python_ir as python_ir - - -# Preprocesses function -# Currently tries to map torch.foo lookups to thunder.torch.foo lookups -@script_instrumentation.record -def preprocess(fn, is_module): - gr = script_frontend.acquire_method(fn.forward if is_module else fn) - passes.unroll_for_loops_and_inline_modules(gr) - if is_module: - ( - additional_param_names, - additional_param_values, - additional_return_names, - ) = passes.module_to_function(gr) - passes.strongly_inline_functions(gr) - passes.torch_to_thunder(gr) - - thunder_fn = python_ir.generate_function(gr) - if is_module: - thunder_fn._additional_param_names = additional_param_names - thunder_fn._additional_param_values = additional_param_values - thunder_fn._additional_return_names = additional_return_names - else: - thunder_fn._additional_param_names = None - thunder_fn._additional_param_values = None - thunder_fn._additional_return_names = None - - return thunder_fn - - # A class that holds data about the compiled object, including statistics about how it's been called # TODO Better document the module-related data the preprocessing harvests, # like additional_param_names @@ -189,13 +153,17 @@ def __init__( use_rematerialization: bool = False, debug_log: None | StringIO = None, compile_options: dict[str, Any] = {}, + get_computation_and_inputs: Callable | None = None, ): # Records whether we're using the thunder.jit() entrypoint or not # The thunder.jit() entrypoint introduces important architectural updates, - # but some components are still designed to work with older architectures for + # but some components are still designed to work with the older entrypoint # and are being temporarily maintained to facilitate their development. self.using_jit = using_jit + # runs prologues to get the compute/backward/epilogue function and inputs + self.get_computation_and_inputs = get_computation_and_inputs + # Resolves cache option self.cache_option = resolve_cache_option(cache_option) @@ -262,20 +230,9 @@ def __init__( self.num_constant_args = 0 self._processed_function: Callable - if disable_preprocessing: - self._processed_function = fn - else: - warnings.warn( - "please use thunder.jit if possible and upgrade and use thunder.jit if it is not yet possible" - ) - self._processed_function = preprocess(fn, is_module=self.is_module) - # TODO Revisit assuming parameters are const - if self.is_module: - self.additional_param_names = self.processed_function._additional_param_names - self.additional_param_values = self.processed_function._additional_param_values - self.additional_return_names = self.processed_function._additional_return_names - self.num_constant_args = len(self.additional_param_values) + assert disable_preprocessing, "please use thunder.jit if you need preprocessing" + self._processed_function = fn # Disallows overwriting processed_function @property @@ -287,7 +244,7 @@ def processed_function(self): def _unpack_inputs(fn, tracectx: TraceCtx, args, kwargs, *, rename_proxies: bool): tracectx.unpacking() - # Translates tensors, arrays, and dtypes to lightning.compile types + # Translates tensors, arrays, and dtypes to thunder.jit types # TODO Translate NumPy dtypes def translate(x: Any, *, name: str | None = None) -> Any: # NOTE Unpacking proxies @@ -371,85 +328,6 @@ def translate(x: Any, *, name: str | None = None) -> Any: return proxyargs, proxykwargs -class ThunderOptimizedModule(torch.nn.Module): # TOM - # todo: subclass nn.Module or forward things like .state_dict() to the - # model - def __init__(self, model, fn, tfn, additional_param_names, additional_param_values, additional_return_names): - super().__init__() - self._model = model - self._forward_fn = fn - self._tfn = tfn - - self._additional_param_values = additional_param_values - self._additional_param_names = additional_param_names - self._additional_return_names = additional_return_names - d = {k: i for i, k in enumerate(additional_param_names)} - self._additional_return_param_idxes = [d[k] for k in additional_return_names] - - def __call__(self, *args, **kwargs): - all_args = (*self._additional_param_values, *args) - res = self._forward_fn(*all_args, **kwargs) - if self._additional_return_names: - res, *additional_returns = res - assert len(self._additional_return_names) == len( - additional_returns - ), f"Number of expected additional return args {len(self._additional_return_names)=} does not match the actual number {len(additional_returns)=}" - for k, v, idx in zip( - self._additional_return_names, additional_returns, self._additional_return_param_idxes - ): - m = self._model - parts = k.split(".") - for p in parts[:-1]: - m = getattr(m, p) - setattr(m, parts[-1], v) - self._additional_param_values[idx] = v - return res - - @contextmanager - def no_sync(self): - """Context manager to disable gradient synchronization in data parallel mode. - - This context manager is intended to be used in conjunction with - :class:`torch.nn.parallel.DistributedDataParallel` to disable gradient - synchronization in the backward pass. It will not have any effect when - used with other modules. - - .. note:: - - This could lead to different accumulated gradients with ``torch.nn.parallel.distributed.DistributedDataParallel.no_sync``. - PyTorch's gradient synchronization is implemented by applying all-reduce to gradient buckets of ``torch.nn.Parameter.grad``. - Thus the ``no_sync`` context leads to :math:`\text{AllReduce} \\left( \\sum_{i = 0}^{\rm{num_grad_accum_steps}} g_i \right)`. - In contrast, this synchronizes accumulated gradients when exiting, leading to - :math:`\text{AllReduce} \\left( \\sum_{i = 0}^{\rm{num_grad_accum_steps - 1}} g_i \right) + \text{AllReduce}(g_{\rm{num_grad_accum_steps}})`. - - .. warning:: - - You must reuse this context manager in each group of gradient accumulation iterations since gradients will get synchronized - on context manager exit. For example: - - .. code-block:: python - - with model.no_sync(): - for _ in range(len(gradient_accumulation_iters)): - loss(model(x)).backward() # uses no-sync-backward trace - loss(model(x)).backward() # uses the regular backward trace - optimizer.step() - - """ - from thunder.distributed import ( - set_skip_data_parallel_grad_sync, - reset_skip_data_parallel_grad_sync, - _sync_grads, - ) - - token = set_skip_data_parallel_grad_sync(True) - try: - yield - finally: - reset_skip_data_parallel_grad_sync(token) - _sync_grads(self) - - # # Caching objects and functions # @@ -458,8 +336,8 @@ def no_sync(self): # TODO Update cacheable types def _make_subkey_for(x: Any) -> tuple[bool, None | tuple]: - if isinstance(x, torch.Tensor): - return True, (torch.Tensor, x.shape, x.device, x.dtype, x.requires_grad) + if isinstance(x, (torch.Tensor, TensorProxy)): + return True, (type(x), x.shape, x.device, x.dtype, x.requires_grad) # TODO Add NumPy ndarray support if isinstance(x, np.ndarray): @@ -750,7 +628,7 @@ def _execute_trace( # Constructs the Python callable c = extrace.python_callable() - # TODO RC1 Remove this option (by modeling torch.compile as another executor) + # TODO RC1 Remove this option (by using the torch.compile executor) if compile_data.use_torch_compile: c = torch.compile(c) @@ -769,11 +647,10 @@ def _execute_trace( # TODO review functions which compute large objects unrelated to tensors and how # they're handled # TODO can the language context be detected from the inputs? -# TODO https://github.com/Lightning-AI/lightning-thunder/issues/316 +# TODO: # Today all tensor outputs will be torch tensors, even if the input was NumPy arrays # provided in the NumPy language ctx -- what should the outputs be? Should we provide # a helper to convert torch tensors to NumPy arrays on output? -# TODO Provide an option to not preprocess (for debugging) def _create_callable( @@ -856,6 +733,9 @@ def _fn(*args, **kwargs) -> tuple[Any, list[TraceCtx]]: # Resets use of compile flags cs.last_compile_reasons = defaultdict(list) with compile_data_and_stats(cd, cs): + traces: list[TraceCtx] = [] + cs.last_traces = traces + cs.last_backward_traces = [] # Determines whether to use autograd.Function or not # autograd.Function (which supports calling .backward() in PyTorch) is used when: # 1) The grad() transform is not applied @@ -886,7 +766,7 @@ def _fn(*args, **kwargs) -> tuple[Any, list[TraceCtx]]: cs.last_trace_host_stop = time.time_ns() return result - # TODO Revisit compile() behavior when hit in a trace ctx + # TODO Revisit jit() behavior when hit in a trace ctx # This will inline the invocation of compile into the current # trace (UNLESS there was a cache hit, per above) # This interaction between the cache and tracing seems odd @@ -914,7 +794,7 @@ def _fn(*args, **kwargs) -> tuple[Any, list[TraceCtx]]: # Starts recording a sequence of traces (this is not inlined) trc: TraceCtx = trc_or_result - traces: list[TraceCtx] = [trc] + traces.append(trc) # Applies transforms for transform in transforms: @@ -954,16 +834,6 @@ def _fn(*args, **kwargs) -> tuple[Any, list[TraceCtx]]: cs.last_trace_host_stop = time.time_ns() return result - if cd.is_module: - _fn = ThunderOptimizedModule( - cd.fn, - _fn, - cd.processed_function, - cd.additional_param_names, - cd.additional_param_values, - cd.additional_return_names, - ) - # NOTE is_module is False _fn._pfn = cd.processed_function _fn._lc_cd = cd diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py index 43513ec489..b6563bbe13 100644 --- a/thunder/core/interpreter.py +++ b/thunder/core/interpreter.py @@ -3597,7 +3597,7 @@ def _check_exc_match_handler(inst: dis.Instruction, /, stack: InterpreterStack, stack.append(isinstance(left, right)) -# TODO https://github.com/Lightning-AI/lightning-thunder/issues/1523 +# TODO See issue "Fix COMPARE_OP handler" # https://docs.python.org/3.10/library/dis.html#opcode-COMPARE_OP @register_opcode_handler("COMPARE_OP") def _compare_op_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None: @@ -4212,8 +4212,6 @@ def _jump_backward_handler(inst: dis.Instruction, /, inst_ptr: int, **kwargs) -> # https://docs.python.org/3.11/library/dis.html#opcode-JUMP_BACKWARD_NO_INTERRUPT -# TODO: we currently ignore the NO_INTERRUPT part, -# https://github.com/Lightning-AI/lightning-thunder/issues/1631 @register_opcode_handler("JUMP_BACKWARD_NO_INTERRUPT", min_ver=(3, 11)) def _jump_backward_no_interrupt_handler(inst: dis.Instruction, /, inst_ptr: int, **kwargs) -> int: assert type(inst.arg) is int @@ -4490,7 +4488,6 @@ def _load_global_handler( return check_and_append(stack, obj) -# TODO https://github.com/Lightning-AI/lightning-thunder/issues/1525 # https://docs.python.org/3.11/library/dis.html#opcode-LOAD_METHOD @register_opcode_handler("LOAD_METHOD") def _load_method_handler( @@ -4524,7 +4521,6 @@ def _load_method_handler( stack.append(meth) -# TODO https://github.com/Lightning-AI/lightning-thunder/issues/1661 # https://docs.python.org/3.11/library/dis.html#opcode-LOAD_NAME @register_opcode_handler("LOAD_NAME") def _load_name_handler( @@ -4567,7 +4563,6 @@ def _make_cell_handler(inst: dis.Instruction, /, frame: InterpreterFrame, **kwar frame.localsplus[i] = c -# TODO https://github.com/Lightning-AI/lightning-thunder/issues/1526 # https://docs.python.org/3.10/library/dis.html#opcode-MAKE_FUNCTION @register_opcode_handler("MAKE_FUNCTION") def _make_function_handler( @@ -5077,7 +5072,6 @@ def do_raise(exc: Any = Py_NULL(), cause: Any = Py_NULL()) -> Literal[INTERPRETE return INTERPRETER_SIGNALS.EXCEPTION_RAISED -# TODO https://github.com/Lightning-AI/lightning-thunder/issues/1660 # https://docs.python.org/3.11/library/dis.html#opcode-PRINT_EXPR @register_opcode_handler("PRINT_EXPR") def _print_expr_handler( @@ -5350,7 +5344,6 @@ def impl(tos, name, tos1): return res -# TODO https://github.com/Lightning-AI/lightning-thunder/issues/1552 # https://docs.python.org/3.10/library/dis.html#opcode-STORE_DEREF @register_opcode_handler("STORE_DEREF") def _store_deref_handler( @@ -5651,7 +5644,7 @@ def _send_handler( ) -> None | int | INTERPRETER_SIGNALS: # SEND(delta) # Equivalent to STACK[-1] = STACK[-2].send(STACK[-1]). Used in yield from and await statements. - # If the call raises StopIteration, pop the top value from the stack, push the exception’s value attribute, and increment the bytecode counter by delta. + # If the call raises StopIteration, pop the top value from the stack, push the exception's value attribute, and increment the bytecode counter by delta. assert isinstance(inst.arg, int) send_value = stack.pop() generator = stack[-1] @@ -6044,7 +6037,10 @@ def _impl(fn, *args, **kwargs): return _interpret_call(unbound_fn, slf, *args, **kwargs) # (2) Handles lookasides - lookaside_fn: None | Callable = compilectx.lookaside(fn, *args, **kwargs) + lookaside_fn: INTERPRETER_SIGNALS | None | Callable = compilectx.lookaside(fn, *args, **kwargs) + if lookaside_fn is INTERPRETER_SIGNALS.EXCEPTION_RAISED: + # Happens with sharp edges, for example + return lookaside_fn if lookaside_fn: runtimectx.record_lookaside(lookaside_fn) res = lookaside_fn(*args, **kwargs) @@ -6330,7 +6326,7 @@ def _run_frame( assert len(frame.interpreter_stack) >= try_block.level + 3 with frame.interpreter_stack.set_cur_instruction(PseudoInst.EXCEPTION_HANDLER): del frame.interpreter_stack[try_block.level + 3 :] - exc_type = frame.interpreter_stack.pop() # we ignore that and asume == type(exc_value) + exc_type = frame.interpreter_stack.pop() # we ignore that and assume == type(exc_value) exc_value = frame.interpreter_stack.pop() exc_traceback = frame.interpreter_stack.pop() if exc_value != None: diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index e1c1f92826..e626a77f2a 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -193,7 +193,8 @@ def is_uncopyable(val: Any, /) -> bool: # - calling a function with a side effect (e.g. randn, print) # TODO RC1 What kind of error should a sharp edge raise? # TODO RC1 Improve sharp edges warnings and errors to show the source line -# https://github.com/Lightning-AI/lightning-thunder/issues/2099 +# See issue "jit: Improve "sharp edges" errors and warnings to show the sharp +# edge's source location" # Context for the minimal interpreter @@ -643,7 +644,6 @@ def decorator(fn: Callable): # general_jit lookasides # -# TODO Add all general_jit operation translations (see https://github.com/Lightning-AI/lightning-thunder/issues/1804) _general_jit_lookaside_map = {} @@ -884,14 +884,23 @@ def general_jit_lookaside(fn, *args, **kwargs) -> None | Callable: lookaside = default_lookaside(fn, *args, **kwargs) if lookaside is None: - if is_opaque(fn) and fn not in _safe_functions: + + def is_from_torch(fn): + return hasattr(fn, "__module__") and fn.__module__ and fn.__module__.startswith("torch") + + if is_opaque(fn) and is_from_torch(fn): + if fn.__module__.startswith("torch._C"): + return lookaside + + # Torch functions have __name__ defined + fn_name = f"{fn.__module__}.{fn.__name__}" + + # For now, only torch-like opaque functions are sharp edges return _general_jit_sharp_edge( - f"Trying to call opaque function {extract_callable_name(fn)}, but it's unsupported. Please file an issue requesting supporting.", + f"Trying to call function {fn_name}, but it's unsupported. Please file an issue requesting support.", None, ) - return None - return lookaside @@ -917,9 +926,15 @@ def _general_jit_const_callback(value: Any) -> WrappedValue: # TODO(nikitaved): maybe call it upon Frame creation def _maybe_update_proxy_name(orig_value: Any, name: str): + # Names that we do not re-name proxies into as these are reserved + proxy_rename_ignore_names = { + "fn", # For example, `fn = globals()['__function_obj']` in prologue + "obj", # For example, `obj = fn.forward` in prologue + } + uvalue = unwrap(orig_value) - if isinstance(uvalue, Proxy) and is_proxy_name_available(name): + if isinstance(uvalue, Proxy) and (name not in proxy_rename_ignore_names) and is_proxy_name_available(name): uvalue_var = variableify(uvalue) rename_proxy_swapmap = get_general_jit_ctx()._proxy_swapmap if uvalue_var not in rename_proxy_swapmap: @@ -927,8 +942,11 @@ def _maybe_update_proxy_name(orig_value: Any, name: str): rename_proxy_swapmap[uvalue_var] = uvalue_renamed -def _apply_trace_proxy_rename(trace: TraceCtx, name: None | str = None) -> TraceCtx: - rename_proxy_swapmap = get_general_jit_ctx()._proxy_swapmap +def _apply_trace_proxy_rename( + trace: TraceCtx, rename_proxy_swapmap: None | dict[Variable, Proxy] = None, name: str | None = None +) -> TraceCtx: + if rename_proxy_swapmap is None: + rename_proxy_swapmap = get_general_jit_ctx()._proxy_swapmap new_trace = from_trace(trace) @@ -966,24 +984,7 @@ def proxy_name_replacer(arg: Any): def _general_jit_global_callback(orig_value: Any, name: str) -> Any: _maybe_update_proxy_name(orig_value, name) - # Allows loading the torch module - value = orig_value - if ( - value is torch - or (value is torch.nn.modules.module._global_backward_pre_hooks) - or (value is torch.nn.modules.module._global_backward_hooks) - or (value is torch.nn.modules.module._global_forward_hooks) - or (value is torch.nn.modules.module._global_forward_pre_hooks) - or (value is torch.nn.functional) - or (value is thunder.core.proxies.get_langctx) - or (value is prop_lookaside_helper) - ): - return value - - return _general_jit_sharp_edge( - f"Tried to loading global {name}. Global support is limited.", - value, - ) + return orig_value _safe_provenance_inst = { @@ -1424,11 +1425,24 @@ def thunder_general_jit( if epilogue_trace: bind_inputs("epilogue", epilogue_trace, pro_to_epi + comp_to_epi, pro_to_epi_proxies + comp_to_epi_proxies) - with general_jit_ctx(ctx): - # TODO(nikitaved): update prologue/epilogue as well - computation_trace = _apply_trace_proxy_rename(computation_trace, "computation") - if epilogue_trace: - # TODO: is it safe to use current swapdict here? - epilogue_trace = _apply_trace_proxy_rename(epilogue_trace, "epilogue") + # Returns a new swapmap dictionary which has the keys (ctx._proxy_swapmap.key() & variableify(proxies)) + def restrict_proxy_swapmap(proxies: tuple[Proxy]) -> dict[Variable, Proxy]: + proxy_swapmap = ctx._proxy_swapmap + proxy_vars = {variableify(p) for p in proxies} + common_vars = proxy_swapmap.keys() & proxy_vars + restricted_proxy_swapmap = {v: proxy_swapmap[v] for v in common_vars} + return restricted_proxy_swapmap + + # Update prologue trace by renaming proxies which are passed from prologue to the computation trace + prologue_trace = _apply_trace_proxy_rename(prologue_trace, restrict_proxy_swapmap(pro_to_comp_proxies)) + + # Update computation trace by renaming proxies which are in the ctx._proxy_swapmap + computation_trace = _apply_trace_proxy_rename(computation_trace, ctx._proxy_swapmap, "computation") + + # Update epilogue trace by renaming proxies which are passed to the epilogue trace from prologue and computation traces + if epilogue_trace: + epilogue_trace = _apply_trace_proxy_rename( + epilogue_trace, restrict_proxy_swapmap(pro_to_epi_proxies + comp_to_epi_proxies), "epilogue" + ) return prologue_trace, computation_trace, epilogue_trace diff --git a/thunder/core/langctxs.py b/thunder/core/langctxs.py index 4f9031d64e..bcc12b6141 100644 --- a/thunder/core/langctxs.py +++ b/thunder/core/langctxs.py @@ -9,7 +9,7 @@ # Context variables, context managers, and helpers related to setting the language context. # The language context is a context variable that determines how methods on proxies are resolved. # For example, in NumPy, ndarray.size returns the number of elements in the array. In PyTorch, -# torch.Tensor.size(dim=None) returns the tenor's shape when dim is None, and the length of the +# torch.Tensor.size(dim=None) returns the tensor's shape when dim is None, and the length of the # specified dimension when dim specifies a dimension (using an integer offset). # diff --git a/thunder/core/options.py b/thunder/core/options.py index a71c6ac628..521af983eb 100644 --- a/thunder/core/options.py +++ b/thunder/core/options.py @@ -170,16 +170,16 @@ def resolve_sharp_edges_option(x: Any, /) -> SHARP_EDGES_OPTIONS: elif isinstance(x, str): seo = _str_to_sharp_edges_option(x) - if seo is None: - _unknown_option("sharp edges", _str_to_sharp_edges_options_map.keys(), "allow", x) - - if seo is SHARP_EDGES_OPTIONS.WARN: - warnings.warn( - f"The 'warn' sharp edges option is experimental and still in development. It may not work as expected." - ) - if seo is SHARP_EDGES_OPTIONS.ERROR: - warnings.warn( - f"The 'error' sharp edges option is experimental and still in development. It may not work as expected." - ) + if seo is None: + _unknown_option("sharp edges", _str_to_sharp_edges_options_map.keys(), "allow", x) + + if seo is SHARP_EDGES_OPTIONS.WARN: + warnings.warn( + f"The 'warn' sharp edges option is experimental and still in development. It may not work as expected." + ) + if seo is SHARP_EDGES_OPTIONS.ERROR: + warnings.warn( + f"The 'error' sharp edges option is experimental and still in development. It may not work as expected." + ) return seo diff --git a/thunder/core/prims.py b/thunder/core/prims.py index f477ee2778..035106bb30 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -2841,7 +2841,7 @@ def slice_meta( # NOTE: slice is named "slice_prim" and not "slice" because it conflicts with Python's "slice" builtin -slice_prim = make_prim(PrimIDs.SLICE, "slice", meta=slice_meta, tags=(OpTags.SHAPE_OP,)) +slice_prim = make_prim(PrimIDs.SLICE, "slice_prim", meta=slice_meta, tags=(OpTags.SHAPE_OP,)) def squeeze_meta(a: TensorProxy, /, dims: tuple[int, ...]) -> TensorProxy: diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index e4c0ba01e6..92138bfcbc 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -793,7 +793,7 @@ def __rxor__(self, other): # # Shift operations # - # Issue https://github.com/Lightning-AI/lightning-thunder/issues/594 + # Issue "Implement logical and arithmetic left and right shifts" # tracks implementing these def __lshift__(self, other): @@ -1461,6 +1461,11 @@ def __rmatmul__(self, other): # Transposes # + @property + def T(self): + method = resolve_method("T", self) + return method(self) + @property def mT(self): method = resolve_method("mT", self) diff --git a/thunder/core/script/algorithms.py b/thunder/core/script/algorithms.py deleted file mode 100644 index bced598172..0000000000 --- a/thunder/core/script/algorithms.py +++ /dev/null @@ -1,199 +0,0 @@ -from __future__ import annotations - -from collections.abc import Iterable, Mapping -import itertools -import textwrap -from typing import Generic, ParamSpec, TypeVar, cast - -import networkx as nx -from typing_extensions import Self - -from thunder.core.utils import OrderedSet - -__all__ = ("flatten_map", "sort_adjacent", "compute_condense_map") -P = ParamSpec("P") -T = TypeVar("T") - - -# ============================================================================= -# == nx.(Di)Graph, but with more safety ======================================= -# ============================================================================= -class TypedGraph(nx.Graph, Generic[T]): # type: ignore[misc, no-any-unimported] - def __init__(self, edgelist: Iterable[tuple[T, T]] = ()) -> None: - super().__init__() - self.add_edges_from(edgelist) - - @property - def nodes(self) -> Iterable[T]: - return cast(Iterable[T], super().nodes) - - @property - def edges(self) -> Iterable[tuple[T, T]]: - return cast(Iterable[tuple[T, T]], super().edges) - - @property - def connected_components(self) -> Iterable[set[T]]: - return cast(Iterable[set[T]], nx.connected_components(self)) - - def subgraph(self, nodes: Iterable[T]) -> Self: - return cast(Self, super().subgraph(nodes)) - - def to_undirected_class(self) -> type: - return TypedGraph[T] - - def to_directed_class(self) -> type: - return TypedDiGraph[T] - - -class TypedDiGraph(TypedGraph[T], nx.DiGraph): # type: ignore[misc, no-any-unimported] - def assert_directed_acyclic(self) -> None: - if not nx.is_directed_acyclic_graph(self): - cycle = "\n".join(f"{node}" for node, _ in nx.find_cycle(self)) - raise AssertionError(f"Cycle detected:\n{textwrap.indent(cycle, ' ' * 4)}") - - def to_undirected(self, *args: P.args, **kwargs: P.kwargs) -> TypedGraph[T]: - G = super().to_undirected(*args, **kwargs) - assert isinstance(G, TypedGraph) - return G - - def predecessors(self, n: T) -> Iterable[T]: - return cast(Iterable[T], super().predecessors(n)) - - -# ============================================================================= -# == Graph algorithms ========================================================= -# ============================================================================= -def flatten_map(mapping: Mapping[T, T]) -> Iterable[tuple[T, T]]: - """If one interprets the items as edges of a tree (or forest), return items for a tree of at most depth two. - - For example, `{1: 2, 2: 3, 4: 3, 5: 6}` flattens to `{1: 3, 2: 3, 4: 3, 5: 6}`. - """ - G = TypedDiGraph[T](((j, i) for i, j in mapping.items() if i != j)) - assert nx.is_directed_acyclic_graph(G) - for cluster in nx.connected_components(G.to_undirected()): - (root,) = (i for i in cluster if not G.in_degree(i)) - yield from ((i, root) for i in cluster if i != root) - - -def _extract_paths(G: TypedDiGraph[T]) -> Iterable[tuple[T, ...]]: - assert nx.is_connected(G.to_undirected()) - subgraph = TypedDiGraph[T]() - subgraph.add_nodes_from(G.nodes) - subgraph.add_edges_from(edge for edge, adjacent in nx.get_edge_attributes(G, "adjacent").items() if adjacent) - assert len(subgraph) == len(G) - subgraph.assert_directed_acyclic() - - for nodes in subgraph.to_undirected().connected_components: - path = subgraph.subgraph(nodes) - sorted_nodes = tuple(nx.topological_sort(path)) - assert len(sorted_nodes) == 1 or nx.is_simple_path(path, sorted_nodes) - yield sorted_nodes - - -def sort_adjacent(G: TypedDiGraph[T]) -> Iterable[T]: - """Sort nodes, respecting strong adjacency requirements and trying to sort return blocks to the end. - - If edges are annotated with `is_return` the annotation will be used; otherwise - terminal nodes will be inferred. - """ - (root,) = (node for node in G.nodes if not G.in_degree(node)) - is_return = nx.get_node_attributes(G, "is_return") or {node: not G.out_degree(node) for node in G.nodes} - assert len(is_return) == len(G) and any(is_return.values()) - - sort_map = {} - for primary_key, sorted_nodes in enumerate(_extract_paths(G)): - if sorted_nodes[0] is root: - primary_key = -1 - elif is_return[sorted_nodes[-1]]: - primary_key += len(G) - - sort_map.update({node: (primary_key, idx) for idx, node in enumerate(sorted_nodes)}) - - assert len(sort_map) == len(G) - yield from sorted(sort_map, key=lambda node: sort_map[node]) - - -def sort_adjacent_dfs(G: TypedDiGraph[T]) -> Iterable[T]: - """Alternate sorting formulation. Prioritizes program order over moving returns to the end. - - Unlike `sort_adjacent`, this order guarantees that at least one dependency will have - appeared before the current block. (`undo_ssa` seems to depend on this invariant.) - """ - paths = {sorted_nodes[0]: sorted_nodes for sorted_nodes in _extract_paths(G)} - - condensed = {} - for path_root, path in paths.items(): - condensed.update({node: path_root for node in path}) - - G_traverse = TypedDiGraph[T]((condensed[source], condensed[sink]) for source, sink in G.edges) - G_traverse.add_nodes_from(paths) - for i in nx.dfs_preorder_nodes(G_traverse): - yield from paths.pop(i) - - assert not paths - - -def compute_condense_map(edges: Iterable[tuple[T, T]]) -> dict[T, OrderedSet[T]]: - """Given a graph of identity relations (including unions and cycles), determine a minumum basis. - - A common construct that emerges from program loops is the statement "A is either A or B". However - if we eliminate the vacuous "A is A" component we reach the much more useful "A is B", which - allows us to replace a thorny union with a simple value. Similarly, we can eliminate chains of - equality expressions. ("C is B, B is A" becomes "C is A, B is A") - - At first this seems as simple as finding the roots of the graph, but consider the following: - "B is A, C is either B or D". B can be replaced with A, but C is the union of A and D. Critically, - B is NOT D, so simply assigning all non-roots the union of the roots is incorrect. - - This function uses an iterative method to distil the graph. Note that there is significant - simplification; the input can be an arbitrary directed **cyclic** graph (as long as at least one - node is not part of a cycle), but the output constituents are trees of at most depth two. - """ - G = TypedDiGraph(edges) - G.remove_edges_from(nx.selfloop_edges(G)) - - condense_map: dict[T, OrderedSet[T]] = {node: OrderedSet() for node in G} - for subgraph_nodes in G.to_undirected().connected_components: - subgraph = cast(TypedDiGraph[T], G.subgraph(subgraph_nodes)) - roots = OrderedSet(node for node in subgraph_nodes if not subgraph.in_degree(node)) - assert roots, subgraph.edges - - equality_edges = OrderedSet((node, node) for node in subgraph.nodes) - while True: - # Condense pairs in `equality_edges`. For example, given the - # following graph and `equality_edges`: - # 0 → 1 → 2 → 3 → 4 → 5 - # ↑┄──┘ - # - # equality_edges = {(0, 1), (3, 4)} - # - # After grouping we're left with: - # {0, 1} → 2 → {3, 4} → 5 - clusters: dict[T, T] = {} - for cluster in TypedGraph(equality_edges).connected_components: - # The choice of "canonical" value is arbitrary as long as it is consistent. - canonical = next(iter(cluster)) - clusters.update((i, canonical) for i in cluster) - - assert len(clusters) == len(subgraph) - reduced_edges = ((clusters[i], clusters[j]) for i, j in subgraph.edges) - reduced_subgraph = cast(TypedDiGraph[T], TypedDiGraph[T](reduced_edges)) # MyPy can't figure this out... - reduced_subgraph.remove_edges_from(nx.selfloop_edges(reduced_subgraph)) - num_equality_edges = len(equality_edges) - - # Condense chains. - equality_edges.update(reduced_subgraph.edges) - - # Condense loops. - for cycle in nx.simple_cycles(reduced_subgraph): - equality_edges.update(zip(cycle, itertools.chain(cycle[1:], cycle[:1]))) - - if len(equality_edges) == num_equality_edges: - # No progress has been made, exit loop. - break - - for root in roots: - for reachable in itertools.chain([root], *nx.dfs_successors(subgraph, root).values()): - condense_map[reachable].add(root) - - return condense_map diff --git a/thunder/core/script/frontend.py b/thunder/core/script/frontend.py deleted file mode 100644 index 62980449bf..0000000000 --- a/thunder/core/script/frontend.py +++ /dev/null @@ -1,684 +0,0 @@ -import collections -import functools -import dis -import inspect -import itertools -import sys -from typing import Optional, TypeVar -from collections.abc import Callable -from collections.abc import Iterable - -import networkx as nx - -from thunder.core.script.graph import ( - check_graph, - replace_values, - Block, - Graph, - MROAwareObjectRef, - Node, - NULL, - PhiValue, - SourceInformation, - Value, -) -from thunder.core.script.instrumentation import record -from thunder.core.script import parse, values -from thunder.core.script.protograph import ProtoBlock, ProtoGraph, ProtoGraphTransform -from thunder.core.script.protograph_passes import apply_protograph_passes -from thunder.core.script.python_ir_data import get_instruction, SUPPORTS_PREPROCESSING -from thunder.core.utils import debug_asserts_enabled, OrderedSet - -T = TypeVar("T") - - -class Super: - pass - - -class PruneEpilogues(ProtoGraphTransform): - """Remove the `POP_TOP, ..., JUMP_ABSOLUTE` blocks introduced during parsing. - - NOTE: This is only for `_bind_to_graph`. The reason is that it produces a - ProtoGraph with mismatched stacks. (Since we've pruned POP_TOP ops.) - This isn't a problem since `_bind_to_graph` is value based, however - it does make `_inter_block_edges` unsafe. - """ - - def _apply(self) -> ProtoGraph | None: - retain: dict[ProtoBlock, ProtoBlock] = {} - for protoblock in self.protograph: - if isinstance(protoblock, ProtoGraph): - breakpoint() - instructions = tuple(i for i, _ in protoblock.flow.symbolic) - if all(isinstance(i, parse.EpilogueFixup) for i in instructions): - assert all(i.opname == parse.POP_TOP for i in instructions[:-1]) - assert instructions[-1].opname == parse.JUMP_ABSOLUTE, instructions[-1] - continue - - retain[protoblock] = new_protoblock = ProtoBlock(protoblock.flow) - new_protoblock.uses.update(protoblock.uses) - - for old, new in retain.items(): - for target, jump in old.jump_targets: - if target not in retain: - ((target, _),) = target.jump_targets - assert target in retain - new.add_jump_target(retain[target], jump) - - if len(retain) != len(tuple(self.protograph)): - return ProtoGraph(retain.values(), provenance=(self.__class__, self.protograph)) - return None - - -def _bind_to_graph( - proto_graph: ProtoGraph, - func: Callable, - method_self: object | None = None, - mro_klass: type | None = None, -) -> Graph: - """Convert abstract value graph into a concrete Graph. - - The key nuance of this conversion is that the mapping from `AbstractValue` - to `Value` is contextual. The first time we "see" an `AbstractValue` it - maps to a `Value`. If we encounter it in any other block it maps to a - PhiValue and we need to set the proper connectivity. - - This is perhaps clearer with an example. Suppose you have an argument `x` - which is used by the root block and passed to the next block, and suppose - you have another value `y` which is created in the root block and passed to - the next block. In the abstract flow this is represented as: - ________ ___________ - `x` -> | Root | -`x`-> | Block 1 | -> ... - | `y` | -`y`-> | | - -------- ----------- - - On the other hand, `Graph` represents the same connectivity as: - ________ ___________ - `x` ←┈┈→ `𝜙x_0` -> | Root | -`𝜙x_0` ←┈┈→ `𝜙x_1` -> | Block 1 | -> ... - | `y` | -`y` ←┈┈┈┈┈→ `𝜙y_0` -> | | - -------- ----------- - - (This diagram does not show the reason for PhiValues: to accept multiple inputs.) - """ - # Peek at the signature and live objects to create Values. This is the - # *only* region where this is permitted. - # ========================================================================= - # TODO(robieta): Lazily generate specializations during runtime. - signature = inspect.signature(func) - func_globals = {**func.__builtins__, **func.__globals__, **{"super": Super()}} - - # NOTE: - # `inspect.signature` will expose parameters in intuitive order. However that - # is not necessarily how Python represents them internally. Specifically, varargs - # and varkwargs are moved to the end. This convention is load bearing (since it - # allows the interpreter index into a flat args array) so we must respect it - # here. (`func.__code__.co_varnames` is the canonical ordering.) - arg_ordered_parameters = func.__code__.co_varnames[: len(signature.parameters)] - source_file_name = inspect.getsourcefile(func) - source_start_line = func.__code__.co_firstlineno - if set(arg_ordered_parameters) != set(signature.parameters): - assert hasattr(func, "__wrapped__") - msg = f"({', '.join(arg_ordered_parameters)}) != ({', '.join(signature.parameters.keys())})" - raise NotImplementedError(msg) - - co_name = func.__code__.co_name - self_key: parse.VariableKey | None = None - self_value: Value | None = None - if method_self is not None: - self_key = parse.VariableKey(arg_ordered_parameters[0], parse.VariableScope.LOCAL) - self_value = Value(value=method_self, name=self_key.identifier, is_function_arg=True) - - get_initial_value_cache = {} - - def get_initial_value(key: parse.VariableKey, block: Block | None = None) -> Value: - if key in get_initial_value_cache: - v = get_initial_value_cache[key] - assert not ((block is None or block != v.block) and not (v.is_global or v.is_const or v.is_function_arg)) - return v - if key.is_const: - v = Value(value=key.identifier, is_const=True) - get_initial_value_cache[key] = v - return v - - elif key == self_key: - v = self_value - get_initial_value_cache[key] = v - return v - - name = key.identifier - assert isinstance(name, str) - if key.scope == parse.VariableScope.LOCAL: - if (p := signature.parameters.get(name)) is not None: - v = Value(typ=p.annotation, name=name, is_function_arg=True) - get_initial_value_cache[key] = v - return v - v = Value(value=NULL, name=name, block=block) - get_initial_value_cache[key] = v - return v - - if key.scope == parse.VariableScope.NONLOCAL: - msg = f"nonlocal variables are not supported but (key, name) = ({key}, {name}) found" - raise RuntimeError(msg) - - if key.scope == parse.VariableScope.GLOBAL: - try: - val = func_globals[name] - except KeyError: - raise ValueError(f"Could not resolve global variable: {name=}.") - v = Value(name=name, value=val, is_global=True) - get_initial_value_cache[key] = v - return v - - raise ValueError(f"Unhandled key: {key=}, name: {name=}") - - del func - # End live inspection region. - # ========================================================================= - assert proto_graph is proto_graph.link() - proto_graph = PruneEpilogues(proto_graph).apply(or_default=True) - blocks = {protoblock: Block() for protoblock in proto_graph} - blocks[proto_graph.root].jump_sources.append(None) - - # Block inputs require special handling since we may need to create `PhiValue`s. - input_conversions = {} - for protoblock, block in blocks.items(): - for key, abstract_value in protoblock.flow.begin_state: - abstract_value = abstract_value.identity - if protoblock is proto_graph.root: - value = get_initial_value(key, block=block) - if key.scope == parse.VariableScope.LOCAL and value.value is not NULL: - assert isinstance(abstract_value, values.ExternalRef), abstract_value - value = PhiValue([value], [None], block) - - elif key in protoblock.uses: - value = PhiValue([], [], block) - - else: - value = Value(value=NULL, block=block) - - input_conversions[(abstract_value, protoblock)] = value - - convert_cache = {} - - def convert(value: values.AbstractValue, protoblock: ProtoBlock, block: Block) -> Value: - value = value.identity - v = convert_cache.get((value, protoblock)) - if v is not None: - if ( - v.block != block - and block is not None - and not (v.is_global or v.is_function_arg or v.is_const or v.value == NULL) - ): - raise AssertionError("ohoh, this should not happen") - return v - - def _convert(value: values.AbstractValue, protoblock: ProtoBlock) -> Value: - assert not value.is_detail, value - if (out := input_conversions.get((value, protoblock), missing := object())) is not missing: - return out - - if isinstance(value, values.NonPyObject): - assert value.tag == values.NonPyObject.Tag.MISSING - return Value(value=NULL, block=block) - - elif isinstance(value, (values.IntermediateValue, values.CompositeValue, values.AbstractPhiValue)): - # For now we discard any information and just treat them as opaque. - # TODO(robieta): refine - return Value(block=block) - - elif isinstance(value, values.ExternalRef) and value.key.is_const: - return get_initial_value(value.key, block=block) - - raise ValueError(f"Cannot convert abstract value: {value}, {protoblock} {protoblock is proto_graph.root=}") - - v = _convert(value, protoblock) - convert_cache[(value, protoblock)] = v - return v - - def make_nodes(protoblock: ProtoBlock, block: Block) -> Iterable[Node]: - for instruction, node_flow in protoblock.flow.materialized.items(): - node = Node( - i=instruction, - inputs=[convert(v, protoblock, block) for v in node_flow.inputs], - outputs=[convert(v, protoblock, block) for v in node_flow.outputs], - ) - node.source_infos = [ - SourceInformation( - orig_file_name=source_file_name, - orig_line_no=instruction.line_no + source_start_line, - orig_end_line_no=instruction.line_no + source_start_line, - gen_line_no=instruction.line_no, - gen_end_line_no=instruction.line_no, - col_offset=0, - end_col_offset=999, - ), - ] - - for output in OrderedSet(node.outputs).difference(node.inputs): - if not (output.node or output.is_const or output.is_global): - # output.node can be populated when we deconstruct a previously constructed value (e.g. binary_idx into a tuple from build_tuple) - output.node = node - - if node.i.opname in ("LOAD_ATTR", "LOAD_METHOD"): - # Once we set `parent` (so PhiValue can traverse through it) - # we can prune these just like all other load instructions. - node.outputs[0].parent = node.inputs[0] - node.outputs[0].name = node.i.argrepr - continue - - elif node.i.opname == "CALL_FUNCTION": - # Note: `super` handling is not currently generic. Corner cases - # such as `super(**{})` or `super_alias = super; super_alias()` - # will not be correctly handled. - # TODO(robieta): handle `super` without load bearing names. - if node.i.arg == 0 and isinstance(node.inputs[0].value, Super): - assert self_value is not None, "super() called in free context" - node.outputs[0].value = MROAwareObjectRef(self_value, start_klass=mro_klass) - - elif node.i.opname == "FOR_ITER": - node.outputs[1].node = node - node.outputs[1].name = ".for_item_iter" - - yield node - - # First pass: populate nodes and jump targets. - for protoblock, block in blocks.items(): - block.nodes = list(make_nodes(protoblock, block)) - for target, _ in protoblock.jump_targets: - jump_target = blocks[target] - last_node = block.nodes[-1] - jump_target.jump_sources.append(last_node) - last_node.jump_targets.append(jump_target) - - # Second pass: link blocks. - for protoblock, block in blocks.items(): - block_values = { - k: v - for k, abstract_v in protoblock.flow.begin_state - if isinstance(v := convert(abstract_v, protoblock, block), PhiValue) - } - - block.block_inputs = list(OrderedSet(block_values.values())) - for parent in proto_graph.parents[protoblock]: - parent_state = dict(parent.flow.end_state) - for key, sink in block_values.items(): - source = convert( - parent_state.get(key, values.NonPyObject(values.NonPyObject.Tag.MISSING)), - parent, - block=blocks[parent], - ) - if source.value is not NULL and source not in sink.values: - sink.add_missing_value(v=source, jump_source=blocks[parent].nodes[-1]) - - # Third pass: specify block outputs once we know which Values are passed to another Block. - for protoblock, block in blocks.items(): - outputs = (convert(abstract_value, protoblock, block) for k, abstract_value in protoblock.flow.end_state) - block.block_outputs.update(v for v in outputs if v.phi_values) - - param_keys = tuple(parse.VariableKey(p, parse.VariableScope.LOCAL) for p in arg_ordered_parameters) - missing = { - k: v - for k in proto_graph.root.uses.difference(param_keys) - if k.scope == parse.VariableScope.LOCAL and (v := get_initial_value(k)).value is not NULL - } - assert not missing, f"missing params {missing}" - - gr = Graph(list(blocks.values())) - gr.local_variables_at_start = [get_initial_value(k) for k in param_keys] - - gr.co_name = co_name - # bound_args = [module.forward.__self__] - gr.self_value = self_value - gr.ismethod = self_value is not None - # deal with other flags? - # NESTED, GENERATOR, NOFREE, COROUTINE, ITERABLE_COROUTINE, ASYNC_GENERATOR - gr.co_flags = inspect.CO_OPTIMIZED | inspect.CO_NEWLOCALS - gr.co_argcount = 0 - gr.co_posonlyargcount = 0 - gr.co_kwonlyargcount = 0 - gr.func_defaults = [] - gr.func_kwdefaults = {} - for p in signature.parameters.values(): - if p.kind == inspect.Parameter.POSITIONAL_ONLY: - gr.co_argcount += 1 - gr.co_posonlyargcount += 1 - elif p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: - gr.co_argcount += 1 - elif p.kind == inspect.Parameter.KEYWORD_ONLY: - gr.co_kwonlyargcount += 1 - elif p.kind == inspect.Parameter.VAR_POSITIONAL: - gr.co_flags |= inspect.CO_VARARGS - elif p.kind == inspect.Parameter.VAR_KEYWORD: - gr.co_flags |= inspect.CO_VARKEYWORDS - else: - assert False, f"unknown parameter kind {p.kind}" - - if p.default is not inspect._empty: - if p.kind == inspect.Parameter.KEYWORD_ONLY: - gr.func_kwdefaults[p.name] = p.default - else: - gr.func_defaults.append(p.default) - return gr - - -def acquire_partial( - pfunc: functools.partial, - module: object | None = None, - mro_klass: type | None = None, -) -> Graph: - # This is complicated due to the semantics of calling Python functions. - # The partial wrapper does the following: - # def pfunc.__call__(*args, **kwargs): - # kw = pfunc.keywords.copy() - # kw.update(kwargs) - # return pfunc.func(*pfunc.args, *args, **kw) - - # This means: - # - positional partial_args are applied from the front and once - # they are bound, they are removed from the signature, - # - keyword only args get new defautls, - # - binding a positional arg as a keyword arg effectively (i.e. in how - # it can be set in calls) makes that arg and all args to the right - # keyword only. - # - things that cannot be bound to parameters may show up in varargs - # or kwargs parameters of the function. - - gr = acquire_method(pfunc.func, module, mro_klass) - gr.ensure_links() - - # first we shuffle positional args to kw only if they are in the kwargs of the partial - pos_param_names = [v.name for v in gr.local_variables_at_start[: gr.co_argcount]] - pos_param_names_to_idx = {n: i for i, n in enumerate(pos_param_names)} - kw_pos_param_idx = [pos_param_names_to_idx[k] for k in pfunc.keywords if k in pos_param_names_to_idx] - if kw_pos_param_idx: - # convert positional default args to kw ones - kw_pos_param_min = min(kw_pos_param_idx) - if kw_pos_param_min < gr.co_posonlyargcount: - raise TypeError( - f"cannot bin positional-only argument {pos_param_names[kw_pos_param_min]} as keyword in partial" - ) - - num_to_kw = gr.co_argcount - kw_pos_param_min - if gr.func_defaults: - to_kw = gr.func_defaults[-num_to_kw:] - del gr.func_defaults[-num_to_kw:] - to_kw_names = pos_param_names[-num_to_kw:] - gr.func_kwdefaults.update(zip(to_kw_names, to_kw)) - # convert positional args to kw only - gr.co_kwonlyargcount += num_to_kw - gr.co_argcount -= num_to_kw - - # deal with positional args. some will be mapped to concrete positional args, some might be added to varargs (*args) - if gr.ismethod: - arg_start = 1 - arg_count = gr.co_argcount - 1 - else: - arg_start = 0 - arg_count = gr.co_argcount - - args_to_bind = pfunc.args[:arg_count] - args_for_varargs = pfunc.args[arg_count:] - - # do we need to drop positional default args? - posarg_default_start = gr.co_argcount - len(gr.func_defaults) - posarg_default_to_delete = len(args_to_bind) + arg_start - posarg_default_start - if posarg_default_to_delete > 0: - gr.func_defaults = gr.func_defaults[posarg_default_to_delete:] - - bound_values = gr.local_variables_at_start[arg_start : arg_start + len(args_to_bind)] - del gr.local_variables_at_start[arg_start : arg_start + len(args_to_bind)] - - for bound_value, arg in zip(bound_values, args_to_bind): - bound_value.is_function_arg = False - bound_value.is_const = True - # TODO: check type? - bound_value.value = arg - gr.co_argcount -= 1 - if gr.co_posonlyargcount > 0: - gr.co_posonlyargcount -= 1 - - # handle keyword arguments to concrete parameters, collect in kwargs those for kw-varargs (**kwargs) - param_names_to_idx = { - v.name: i for i, v in enumerate(gr.local_variables_at_start[: gr.co_argcount + gr.co_kwonlyargcount]) - } - kwargs = {} - for argname, argvalue in pfunc.keywords.items(): - idx = param_names_to_idx.get(argname, -1) - if idx == -1: - kwargs[argname] = argvalue - continue - gr.func_kwdefaults[argname] = argvalue - - # for varargs and kwargs fed from partial we need the following prelude: - # TODO: (but maybe we should just have a prelude always for the consts, too...) - # if it has *varargs: - # TMP1 = LOAD_CONST partial_args_for_varargs (needs to be a tuple) - # varargs = TMP1 + varargs - # if it has **kwargs: - # TMP2 = LOAD_CONST partial_kwargs - # kwargs = partial_kwargs | kwargs - - if args_for_varargs or kwargs: - prelude = Block() - prelude.graph = gr - jump_node = Node(i=parse.ThunderInstruction.make_jump_absolute(None), inputs=[], outputs=[]) - jump_node.source_infos = [ - SourceInformation( - orig_file_name="", # filename? - orig_line_no=0, - orig_end_line_no=0, - gen_line_no=0, - gen_end_line_no=0, - col_offset=0, - end_col_offset=999, - ), - ] - - prelude.nodes.append(jump_node) - jump_target = gr.blocks[0] - assert jump_target.jump_sources[0] is None - jump_target.jump_sources[0] = jump_node - jump_node.jump_targets.append(jump_target) - prelude.jump_sources.append(None) - for i in jump_target.block_inputs: - assert i.jump_sources[0] is None - i.jump_sources[0] = jump_node - else: - prelude = None - - # handle *args (varargs) - if args_for_varargs: - if kw_pos_param_idx: - raise TypeError( - f"partial tried to bind {len(pfunc.args)} positional arguments, but only {arg_count} are allowed after keyword binding" - ) - if not (gr.co_flags & inspect.CO_VARARGS): - raise TypeError( - f"partial tried to bind {len(pfunc.args)} positional arguments, but only {arg_count} are allowed" - ) - # the variable for varargs is at gr.co_argcount + gr.co_kwonlyargcount - v_vararg_param = gr.local_variables_at_start[gr.co_argcount + gr.co_kwonlyargcount] - v_partial_varargs = Value(name="partial_varargs", value=tuple(args_for_varargs), is_const=True) - v_varargs_new = Value(name="varargs_with_partial", block=prelude) # type is tuple - pv = PhiValue([v_vararg_param], [None], block=prelude) - new_n = Node( - i=get_instruction(opname="BINARY_ADD", arg=None), - inputs=[v_partial_varargs, pv], - outputs=[v_varargs_new], - ) - # line number? - new_n.source_infos = [ - SourceInformation( - orig_file_name="", # filename? - orig_line_no=0, - orig_end_line_no=0, - gen_line_no=0, - gen_end_line_no=0, - col_offset=0, - end_col_offset=999, - ), - ] - prelude.nodes.insert(0, new_n) - prelude.block_outputs.add(v_varargs_new) - # replace v_vararg_param with v_varargs_new in remainder - replace_values(gr, {v_vararg_param: v_varargs_new}) - prelude.block_inputs.append(pv) - - # handle **kwargs - if kwargs: - if not (gr.co_flags & inspect.CO_VARKEYWORDS): - raise TypeError( - f"function does not have **kwargs but partial tries to bind unknown keywords {tuple(kwargs)}." - ) - - # the variable for varargs is at gr.co_argcount + gr.co_kwonlyargcount - v_kwvararg_param = gr.local_variables_at_start[ - gr.co_argcount + gr.co_kwonlyargcount + (1 if gr.co_flags & inspect.CO_VARARGS else 0) - ] - v_partial_kwvarargs = Value(name="partial_kwvarargs", value=kwargs, is_const=True) - v_kwvarargs_new = Value(name="kwvarargs_with_partial", block=prelude) # type is dict - pv = PhiValue([v_kwvararg_param], [None], block=prelude) - new_n = Node( - i=get_instruction(opname="BINARY_OR", arg=None), - inputs=[v_partial_kwvarargs, pv], - outputs=[v_kwvarargs_new], - ) - # line number? - new_n.source_infos = [ - SourceInformation( - orig_file_name="", # filename? - orig_line_no=0, - orig_end_line_no=0, - gen_line_no=0, - gen_end_line_no=0, - col_offset=0, - end_col_offset=999, - ), - ] - prelude.nodes.insert(-1, new_n) - prelude.block_outputs.add(v_kwvarargs_new) - # replace v_vararg_param with v_varargs_new in remainder - replace_values(gr, {v_kwvararg_param: v_kwvarargs_new}) - prelude.block_inputs.append(pv) - - if prelude: - gr.blocks.insert(0, prelude) - return gr - - -@functools.cache -def _construct_protograph(func): - """Protoblocks are parse level constructs, so it is safe to reuse them.""" - return apply_protograph_passes(ProtoGraph.from_code(func.__code__)) - - -@record -def acquire_method( - method: Callable, - module: object | None = None, - mro_klass: type | None = None, -) -> Graph: - assert SUPPORTS_PREPROCESSING, sys.version_info - if isinstance(method, functools.partial): - return acquire_partial(method, module, mro_klass) - if callable(method) and not inspect.ismethod(method) and not inspect.isfunction(method): - method = method.__call__ - - method_self, func = (method.__self__, method.__func__) if inspect.ismethod(method) else (None, method) - assert not inspect.ismethod(func) - - module = module or method_self - if mro_klass is None and module is not None: - mro_klass = type(module) - - gr = _bind_to_graph(_construct_protograph(func), func, method_self, mro_klass) - gr.source_start_line = 1 - try: - gr.source_lines, _ = inspect.getsourcelines(method) - except OSError: - gr.source_lines = ["# Failed to extract source."] - - gr.method = method - gr.module = module - gr.mro_klass = mro_klass - if debug_asserts_enabled(): - check_graph(gr) - return gr - - -def remove_unused_values(gr: Graph) -> None: - gr.ensure_links() - - def remove_value(v: Value) -> None: - for pv in v.phi_values: - bl = pv.block - pv.remove_value(v) - if not pv.values: - remove_value(pv) - bl.block_inputs.remove(pv) - if pv in bl.block_outputs: - bl.block_outputs.remove(pv) - - for i in gr.blocks[0].block_inputs: - if len(i.values) == 1 and i.values[0] is None: - remove_value(i) - - gr.blocks[0].block_inputs = [i for i in gr.blocks[0].block_inputs if len(i.values) != 1 or i.values[0] is not None] - - values_used = set() - - INDEX_OPS = {"BINARY_SUBSCR"} - - def mark_used(v: Value) -> None: - if v in values_used: - return - values_used.add(v) - if v.node and v.node.i.opname in INDEX_OPS: - for i in v.node.inputs: - mark_used(i) - if v.parent is not None: - mark_used(v.parent) - if isinstance(v, PhiValue): - for w in v.values: - mark_used(w) - - for bl in gr.blocks: - for n in bl.nodes: - if n.i.opname not in INDEX_OPS: - for i in n.inputs: - mark_used(i) - - for bl in gr.blocks: - for i in bl.block_inputs[:]: - if i not in values_used: - for v in i.values[:]: - if v is not None: - i.remove_value(v) - bl.block_inputs.remove(i) - bl.block_outputs = OrderedSet(o for o in bl.block_outputs if o in values_used) - for n in bl.nodes[:]: - if n.i.opname in INDEX_OPS and not any((o in values_used) for o in n.outputs): - bl.nodes.remove(n) - for i in gr.local_variables_at_start: - if i is not None: - i.phi_values = [pv for pv in i.phi_values if pv in values_used] - - for bl in gr.blocks: - for n in bl.nodes: - for o in n.outputs: - o.phi_values = [pv for pv in o.phi_values if pv in values_used] - - # remove things only used in current block (and not in own phi) from outputs - # TODO: think if this would obsolete the above - outputs_used = set() - for bl in gr.blocks: - for i in bl.block_inputs: - assert isinstance(i, PhiValue) - for v in i.values: - outputs_used.add(v) - for bl in gr.blocks: - bl.block_outputs = OrderedSet(o for o in bl.block_outputs if o in outputs_used) - - if debug_asserts_enabled(): - check_graph(gr) diff --git a/thunder/core/script/graph.py b/thunder/core/script/graph.py deleted file mode 100644 index 24c22746df..0000000000 --- a/thunder/core/script/graph.py +++ /dev/null @@ -1,819 +0,0 @@ -# This is a "TorchScript-like" graph representation of Python IR. -# The idea is that blocks are "simple blocks" in terms of the code flow graph, -# i.e. without branches -import collections -import copy -import enum -import inspect -import linecache -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type, TYPE_CHECKING, Set, Union -from collections.abc import Iterable, Iterator, Sequence - -from thunder.core.script.instrumentation import InstrumentingBase -from thunder.core.script.parse import ThunderInstruction -from thunder.core.script.noinline import noinline -from thunder.core.utils import OrderedSet - -if TYPE_CHECKING: - import graphviz - -GraphObject = Union["Value", "Node", "Block"] - - -def assert_value(v: GraphObject | None) -> "Value": - assert isinstance(v, Value) - return v - - -def assert_node(n: GraphObject | None) -> "Node": - assert isinstance(n, Node) - return n - - -def assert_block(bl: GraphObject | None) -> "Block": - assert isinstance(bl, Block) - return bl - - -class GraphSummaryCallback: - def node(self, n: "Node") -> tuple[list[str], list[str]]: - return [], [] - - def finish(self) -> list[str]: - return [] - - -class NULL: - """marker for non-existant object.""" - - pass - - -@dataclass -class SourceInformation: - orig_line_no: int - orig_end_line_no: int - - gen_line_no: int - gen_end_line_no: int - # gen_file_name? --> could be interesting when passing SourceInfo to traces - - col_offset: int - end_col_offset: int - orig_file_name: str = "" - source: Any | None = None - - -class MROAwareObjectRef: # or as they call it super - def __init__(self, obj: Any, start_klass: type | None = None): - self.obj = obj - self.start_klass = start_klass - - def __getattr__(self, name: str) -> Any: - ## handle non-methods... - i = 0 - mro = inspect.getmro(self.obj.value.__class__) - if self.start_klass is not None: - while i < len(mro) and not mro[i] == self.start_klass: - i += 1 - i += 1 - while i < len(mro) and not hasattr(mro[i], name): - i += 1 - if i >= len(mro): - raise AttributeError(f"{name} not a member") - return getattr(mro[i], name) - - -# Represent undefined values e.g. non-existent attrs etc. -# this can be inserted as a (const) value and will then be -# translated into raising an error at runtime -class _Undefined: - def __init__(self, value, attr): - self.value = value - self.attr = attr - - -# Values are -# - function arguments as inputs to the graph (including self) -# - constants and globals -# - intermediate results / local variables -# - attributes of other values given in .parent -# they can be used -# - as inputs and outputs of nodes (but inplace is still tricky) -# - as block_outputs (note that block_outputs can be either outputs of nodes -# or attribute lookups). -# block_outputs (and only these) typically have .phi_values recorded. -# PhiValues are the block_inputs. -# - they have (one or multiple) block_outputs as .values, these are set at the -# .jump_sources (TODO: .jump_sources records None for non-node-generated). -# - There must be a 1-1 correspondence between .phi_values-> and .values->. -# All block_inputs (at least before an optimization pass towards the un-ssa-ing) -# are expected to be PhiValues and all PhiValues are expected to show up as -# block_inputs. -class Value(InstrumentingBase): - def __init__( - self, - *, - node: Optional["Node"] = None, - block: Optional["Block"] = None, - nr: int | None = None, - typ: type | None = None, - value: Any = None, - name: str | None = None, - parent: Optional["Value"] = None, - is_global: bool = False, - is_const: bool = False, - is_function_arg: bool = False, - ): - self.node = node - self.block = block - self.nr = nr - self.typ = typ if typ is not None or value in (None, NULL) else type(value) - self.value = value - self.name = name - self.parent = parent - self.is_global = is_global - self.is_const = is_const - self.is_function_arg = is_function_arg - self.phi_values: list["PhiValue"] = [] - assert not (block is None and not (is_global or is_const or is_function_arg)) - - def resolve(self) -> tuple["Value", ...]: - return (self,) - - def clone(self, translation_dict: dict[GraphObject, GraphObject] | None = None) -> "Value": - # clones a value, including (recursively) parent value - # uses translation_dict to look up parent value - # updates translation_dict - # does not register phi_values on the clone - # always clone parents? - if translation_dict is None: - translation_dict = {} - if self in translation_dict: - return assert_value(translation_dict[self]) - parent = self.parent - if parent: - if parent in translation_dict: - parent = assert_value(translation_dict[parent]) - else: - parent = parent.clone(translation_dict=translation_dict) - v = Value( - node=self.node, - block=self.block, - nr=self.nr, - typ=self.typ, - value=self.value, - name=self.name, - parent=parent, - is_global=self.is_global, - is_const=self.is_const, - is_function_arg=self.is_function_arg, - ) - if translation_dict is not None: - translation_dict[self] = v - return v - - def __str__(self, _value_printer=str) -> str: - parts = [] - if self.is_function_arg: - parts.append("funcarg") - if self.name: - parts.append(f"name={self.name}") - if self.typ is not None: - parts.append(f"typ={self.typ}") - if self.value is not None: - parts.append(f"value of type {type(self.value)}") - if self.is_const: - parts.append("const") - if self.is_global: - parts.append("global") - # if self.block is None: - # parts.append("block-None") - if self.parent is not None: - parts.append(f"parent={_value_printer(self.parent)}") - return f"""{type(self).__name__} {hex(id(self))} ({' '.join(parts)})""" - - def __repr__(self) -> str: - return f"{super().__repr__()[:-1]} {self}>" - - -class PhiValue(Value): - # node? - def __init__( - self, - values: list[Value], - jump_sources: Sequence[Optional["Node"]], - block: "Block", - _unfinished_clone: bool = False, - ): - super().__init__(block=block) - self.block: Block = block # duplicate assignment / declaration? - self._unfinished_clone = _unfinished_clone - self._set_values_jump_sourcess(values, jump_sources) - - def _set_values_jump_sourcess(self, values: list[Value], jump_sources: Sequence[Optional["Node"]]) -> None: - assert len(values) == len(jump_sources) - self.values = list(values) - if not self._unfinished_clone: - for v in self.values: - if v is not None: - v.phi_values.append(self) - self.jump_sources = list(jump_sources) - - def resolve(self) -> tuple[Value, ...]: - to_process = [self] - seen: OrderedSet[Value] = OrderedSet() - while to_process: - seen.add(v := to_process.pop()) - if isinstance(v, PhiValue): - to_process.extend(vi for vi in v.values if vi not in seen) - - return tuple(i for i in seen if not isinstance(i, PhiValue)) - - def clone(self, translation_dict: dict[GraphObject, GraphObject] | None = None) -> "PhiValue": - # due to loops in the Graph, this is complicated: - # we do not translate values or jump_sources here, but do - # translate blocks. - if translation_dict is None: - translation_dict = {} - if self in translation_dict: - v = translation_dict[self] - assert isinstance(v, PhiValue) - return v - v = PhiValue(self.values, self.jump_sources, assert_block(translation_dict[self.block]), _unfinished_clone=True) - translation_dict[self] = v - return v - - def post_process_clone(self, *, translation_dict: dict[GraphObject, GraphObject]) -> None: - assert self._unfinished_clone - self._unfinished_clone = False - self._set_values_jump_sourcess( - [assert_value(translation_dict.get(v, v)) for v in self.values], - [(assert_node(translation_dict.get(js, js)) if js is not None else None) for js in self.jump_sources], - ) - - def add_missing_value( - self, v: Value, idx: int | None = None, jump_source: Optional["Node"] = None - ) -> None: # None: append - if idx is None: - assert v not in self.values - self.values.append(v) - v.phi_values.append(self) - self.jump_sources.append(jump_source) - else: - assert 0 <= idx < len(self.values) - assert self.values[idx] is None - assert jump_source is None - self.values[idx] = v - v.phi_values.append(self) - - def remove_value(self, v: Value) -> None: - idx = self.values.index(v) - v.phi_values.remove(self) - del self.values[idx] - del self.jump_sources[idx] - - def replace_value(self, v_old: Value, v_new: Value) -> None: - if v_old is v_new: - return - - assert v_new not in self.values - idx = self.values.index(v_old) - self.values[idx] = v_new - assert (v_new.is_function_arg or v_new.is_const) or v_new.block.graph is self.block.graph # v_old.block.graph - if v_new.is_function_arg or v_new.is_const: - # TV-TODO: this is actually dubious for constants and we should avoid it - self.jump_sources[idx] = None - else: - self.jump_sources[idx] = v_new.block.nodes[-1] - - v_old.phi_values.remove(self) - v_new.phi_values.append(self) - - -# A node corresponds to one Python bytecode instruction given in .i -# it has Values as .inputs and .outputs -class Node(InstrumentingBase): - def __init__( - self, - *, - i: ThunderInstruction, - inputs: list[Value] | None = None, - outputs: list[Value] | None = None, - source_infos: list[SourceInformation] | None = None, - ): - self.i = i - self.inputs: list[Value] = inputs if inputs is not None else [] - self.outputs: list[Value] = outputs if outputs is not None else [] - self.jump_targets: list[Block] = [] - self.source_infos: list[SourceInformation] = source_infos if source_infos is not None else [] - self.block: Block | None = None - - def clone(self, translation_dict: dict[GraphObject, GraphObject] | None = None) -> "Node": - """.block of the clone will be None if block is not in translation dict.""" - if translation_dict is None: - translation_dict = {} - if self in translation_dict: - return assert_node(translation_dict[self]) - inputs = [i.clone(translation_dict=translation_dict) for i in self.inputs] - outputs = [o.clone(translation_dict=translation_dict) for o in self.outputs] - i = copy.copy(self.i) - n2 = Node(i=i, inputs=inputs, outputs=outputs) - n2.source_infos = copy.deepcopy(self.source_infos) - n2.jump_targets = [assert_block(translation_dict.get(bl, bl)) for bl in self.jump_targets] - if self.block is None: - n2.block = None - else: - bl2 = translation_dict.get(self.block) - assert bl2 is None or isinstance(bl2, Block) - n2.block = bl2 - translation_dict[self] = n2 - return n2 - - def set_jump_target(self, jt: "Block", idx: int | None = None) -> None: - # TODO: more validation? - # is_jump = (self.i.opname not in unconditional_jump_names) or (idx == 1) or (idx is None and self.jump_targets) - # assert is_jump - - if idx is None: - assert len(self.jump_targets) <= 1 - self.jump_targets.append(jt) - else: - old_jt = self.jump_targets[idx] - old_jt.jump_sources.remove(self) - self.jump_targets[idx] = jt - jt.jump_sources.append(self) - - def __str__(self) -> str: - # i.i.offset // 2, i.i.opname, i.i.arg, "(", i.i.argval, ")" - if self.i.opname in {"CALL_METHOD", "CALL_FUNCTION"}: - return f"{self.i.opname}({self.inputs})" - return f"{self.i.opname} {self.i.arg} ({self.i.argval})" # str(self.i) - - def __repr__(self) -> str: - return f"{super().__repr__()[:-1]} {self}>" - - -# Blocks have the first instruction (only) as the jump target -# (or the function entry point) -# Blocks always have a single final instruction that jumps (or RETURN) -# conditional jumps (including e.g. FOR_ITER) always have the non-jumping -# target first and then the jumping target. -# The jump targets are other blocks and are atributes of the jump instruction. -class Block: - def __init__(self): - self.jump_sources: list[Node | None] = [] - self.nodes: list[Node] = [] - self.block_inputs: list[Value] = [] - self.block_outputs = OrderedSet([]) - - def __str__(self) -> str: - return "\n".join([f" Block (reached from {self.jump_sources})"] + [" " + str(n) for n in self.nodes]) - - def __repr__(self) -> str: - return f"{super().__repr__()[:-1]} {self}>" - - def insert_node(self, n: Node, insert_after: Node | None = None, insert_before: Node | None = None) -> None: - assert n.block is None - assert (insert_after is None) != (insert_before is None), f"{insert_after=} {insert_before=}" - to_find = insert_after or insert_before - for idx, n2 in enumerate(self.nodes): - if n2 is to_find: - break - if n2 is not to_find: - raise ValueError(f"could not find node {n}") - - # validity checks? (also above) - n.block = self - if insert_after: - self.nodes.insert(idx + 1, n) - else: - self.nodes.insert(idx, n) - - -# A graph contains Blocks. -# The first block (.blocks[0]) is the entry point. Other blocks are connected -# through jump instructions. -class Graph(InstrumentingBase): - def __init__(self, blocks: list[Block] | None = None): - self.blocks = [] if blocks is None else blocks[:] - - def __str__(self) -> str: - return "\n".join(["Graph of"] + [str(b) for b in self.blocks]) - - def __repr__(self) -> str: - return f"{super().__repr__()[:-1]} {self}>" - - def nodes(self) -> Iterator[Node]: - for b in self.blocks: - yield from b.nodes - - def ensure_links(self) -> None: - for bl in self.blocks: - bl.graph = self - for n in bl.nodes: - n.block = bl - inps = set(n.inputs) - for o in n.outputs: - if o not in inps: # not for inplace - o.block = bl - o.node = n - for o in bl.block_outputs: - if not (o.is_const or o.is_function_arg): - o.block = bl - for i in bl.block_inputs: - i.block = bl - - def clone(self) -> tuple["Graph", dict[GraphObject, GraphObject]]: - bls2, translation_dict = clone_blocks(self.blocks) - g2 = Graph(blocks=bls2) - g2.local_variables_at_start = [v.clone() for v in self.local_variables_at_start] - replace_values(g2, {k: v for k, v in zip(self.local_variables_at_start, g2.local_variables_at_start)}) - g2.ismethod = self.ismethod - g2.co_name = self.co_name - g2.co_argcount = self.co_argcount - g2.co_flags = self.co_flags - g2.co_posonlyargcount = self.co_posonlyargcount - g2.co_kwonlyargcount = self.co_kwonlyargcount - g2.func_defaults = self.func_defaults[:] - g2.func_kwdefaults = self.func_kwdefaults.copy() - g2.method = self.method - g2.module = self.module - g2.mro_klass = self.mro_klass - g2.self_value = self.self_value - g2.source_start_line = self.source_start_line - g2.source_lines = self.source_lines[:] - - return g2, translation_dict - - def print(self) -> None: - value_counter = 1 - print(self.local_variables_at_start) - for bl in self.blocks: - for n in bl.nodes: - for o in n.outputs: - o.print_name = f"{o.name}:{value_counter}" if o.name is not None else f":{value_counter}" - value_counter += 1 - for i in n.inputs: - if not hasattr(i, "print_name"): - i.print_name = f"{i.name}:{value_counter}" if i.name is not None else f":{value_counter}" - value_counter += 1 - av = f"[{n.i.argval}]" if n.i.argval is not None else "" - print( - ",".join(o.print_name for o in n.outputs), - "=", - n.i.opname, - f"{av}(", - ", ".join([i.print_name for i in n.inputs]) + ")", - ) - - def summary(self, print_lines: bool = False, callback=GraphSummaryCallback()) -> None: - type_count = collections.Counter() - results = {} - - def get_name(v): - if v not in results: - idx = type_count[type(v)] - type_count[type(v)] += 1 - prefix = {PhiValue: "𝜙", Value: "V"}.get(type(v), type(v).__name__) - results[v] = (prefix, idx) - - # Populate cache - if isinstance(v, PhiValue): - _ = [get_name(vi) for vi in v.values] - if v.parent is not None: - _ = get_name(v.parent) - - return "{}_{}".format(*results[v]) - - graph_lines = [] - legend_lines = [] - - block_indices = {bl: i for i, bl in enumerate(self.blocks)} - block_jump_indices = {bl.nodes[-1]: i for i, bl in enumerate(self.blocks)} - block_jump_indices[None] = None - - for block in self.blocks: - graph_lines.extend( - ( - f"Block {block_indices[block]} reached from blocks {[block_jump_indices.get(js, 'unknown') for js in block.jump_sources]}", - f"Block inputs: {[get_name(i) for i in block.block_inputs]}", - f"Block outputs: {[get_name(i) for i in block.block_outputs]}", - ) - ) - for i, node in enumerate(block.nodes): - if ( - i == 0 - or node.source_infos - and ( - (not block.nodes[i - 1].source_infos) - or node.source_infos[-1] != block.nodes[i - 1].source_infos[-1] - ) - ): - line_no = node.source_infos[-1].gen_line_no - line = f"# l{line_no + self.source_start_line:3d} {self.source_lines[line_no].rstrip()}" - else: - line = "" - lines_before, lines_after = callback.node(node) - graph_lines.extend(lines_before) - graph_lines.append( - f" {node.i.opname:<20} {f'{[get_name(v) for v in node.inputs]} -> {[get_name(v) for v in node.outputs]}':<80} {line}" - ) - graph_lines.extend(lines_after) - graph_lines.append("") - graph_lines.extend(callback.finish()) - - for v, (prefix, idx) in sorted(results.items(), key=lambda x: x[1]): - values = f"[{', '.join(get_name(vi) for vi in v.values)}]" if isinstance(v, PhiValue) else "" - legend_lines.append(f"{prefix}_{idx} {v.__str__(_value_printer=get_name):<16} {values}") - - if print_lines: - print("\n".join(graph_lines) + "\n" + "\n".join(legend_lines)) - - return tuple(graph_lines), tuple(legend_lines) - - -def unify_values(values: list[Value], jump_sources: list[Node], bl: Block, all_predecessors_done: bool = True) -> Value: - if all_predecessors_done: - if len(values) == 1: - return values[0] - val = values[0] - if all(v is val for v in values[1:]): - return val - # different values - return PhiValue(values, jump_sources, bl) - - -def insert_before(new_n: Node, n: Node) -> None: - bl = assert_block(n.block) - idx = bl.nodes.index(n) - bl.nodes.insert(idx, new_n) - new_n.block = n.block - - -def insert_after(new_n: Node, n: Node) -> None: - bl = assert_block(n.block) - idx = bl.nodes.index(n) - bl.nodes.insert(idx + 1, new_n) - new_n.block = n.block - - -def replace_values(gr_or_bl: Graph | Block, value_map: dict[Value, Value], follow_phi_values: bool = False) -> None: - ### Replacing a value: - # - as inputs/outputs of nodes - # - value.parent for other values - # - phi nodes - # - graph input (?) / initial vars - processed = set() - - def map_values(v: Value) -> Value: - # do not call map_values without guarding for infinite recursion - if v in processed: - return value_map.get(v, v) - processed.add(v) - - if v in value_map: - if follow_phi_values: - for pv in v.phi_values[:]: - pv.replace_value(v, value_map[v]) - assert len(pv.values) == len(pv.jump_sources) - return value_map[v] - - if isinstance(v.value, MROAwareObjectRef): - v.value.obj = map_values(v.value.obj) - if v.parent is not None: - v.parent = map_values(v.parent) - if isinstance(v, PhiValue): - for ov in v.values: - nv = map_values(ov) - v.replace_value(ov, nv) - assert len(v.values) == len(v.jump_sources) - return v - - def process_block(bl: Block) -> None: - bl.block_inputs = [map_values(vv) for vv in bl.block_inputs] - for n in bl.nodes: - n.inputs = [map_values(vv) for vv in n.inputs] - n.outputs = [map_values(vv) for vv in n.outputs] - bl.block_outputs = OrderedSet(map_values(vv) for vv in bl.block_outputs) - - if isinstance(gr_or_bl, Graph): - for bl in gr_or_bl.blocks: - process_block(bl) - elif isinstance(gr_or_bl, Block): - process_block(gr_or_bl) - else: - raise TypeError("replace_values works on Graph or Block objects") - - -## TODO: our should this be a method? -def make_dot(gr: Graph, format: str = "png", add_names: bool = False) -> "graphviz.Digraph": - import graphviz - - dot = graphviz.Digraph(name="thunder_graph", format=format) - - block_idxes = {} - - value_idxes: dict[Value, int] = {} - - for i_bl, bl in enumerate(gr.blocks): - block_idxes[bl] = i_bl - with dot.subgraph(name=f"cluster_bl_{i_bl}") as sub_dot: - for i_i, i in enumerate(bl.block_inputs): - i_nr = len(value_idxes) - value_idxes[i] = i_nr - i_name = f"bi %{i_nr}" - if add_names: - i.name = i_name - v_color = "black" if i not in bl.block_outputs else "red" - sub_dot.node(f"v {i_nr}", label=i_name, color=v_color) - - for i_n, n in enumerate(bl.nodes): - label = n.i.opname - if n.i.opname == "CALL_METHOD": - assert n.inputs[0].name is not None - label = "CM " + n.inputs[0].name - elif n.i.opname == "CALL_FUNCTION" and n.inputs[0].name: - label = "CF " + n.inputs[0].name - sub_dot.node(f"i {i_bl} {i_n}", label, shape="box") - for o in n.outputs: - if o not in value_idxes: - o_nr = len(value_idxes) - value_idxes[o] = o_nr - o_name = o.name or f"%{o_nr}" - if add_names: - o.name = o_name - v_color = "black" if o not in bl.block_outputs else "red" - sub_dot.node(f"v {o_nr}", label=o_name, color=v_color) - else: - o_nr = value_idxes[o] - sub_dot.edge(f"i {i_bl} {i_n}", f"v {o_nr}", color="blue") - if i_n > 0: - sub_dot.edge(f"i {i_bl} {i_n - 1}", f"i {i_bl} {i_n}") - - for i_bl, bl in enumerate(gr.blocks): - for jt_bl in bl.nodes[-1].jump_targets: - dot.edge(f"i {i_bl} {len(bl.nodes) - 1}", f"i {block_idxes[jt_bl]} {0}") - for i in bl.block_inputs: - i_idx = value_idxes[i] - if isinstance(i, PhiValue): - for v in i.values: - if v in value_idxes: - dot.edge(f"v {value_idxes[v]}", f"v {i_idx}", color="green") - - for i_n, n in enumerate(bl.nodes): - for i in n.inputs: - if i in value_idxes: - dot.edge(f"v {value_idxes[i]}", f"i {i_bl} {i_n}", color="blue") - elif isinstance(i, PhiValue): - assert False, "This should be removed?" - for v in i.values: - if v in value_idxes: - dot.edge(f"v {value_idxes[v]}", f"i {i_bl} {i_n}", color="red") - - return dot - - -def clone_blocks( - blocks_to_clone: list[Block], translation_dict: dict[GraphObject, GraphObject] | None = None -) -> tuple[list[Block], dict[GraphObject, GraphObject]]: - if translation_dict is None: - translation_dict = {} - - blocks_todo = [] - for obl in blocks_to_clone: - if obl not in translation_dict: - bl = Block() - translation_dict[obl] = bl - blocks_todo.append(obl) - - for obl in blocks_todo: - bl = assert_block(translation_dict[obl]) - bl.block_inputs = [i.clone(translation_dict=translation_dict) for i in obl.block_inputs] - bl.block_outputs = OrderedSet(o.clone(translation_dict=translation_dict) for o in obl.block_outputs) - bl.nodes = [n.clone(translation_dict=translation_dict) for n in obl.nodes] - for obl in blocks_todo: - bl = assert_block(translation_dict[obl]) - for js in obl.jump_sources: - if js is None: - bl.jump_sources.append(None) - elif js in translation_dict: - bl.jump_sources.append(assert_node(translation_dict[js])) - - for i in bl.block_inputs: - i.post_process_clone(translation_dict=translation_dict) - return [assert_block(translation_dict[bl]) for bl in blocks_to_clone], translation_dict - - -def _check_graph(gr: Graph) -> None: - # some sanity checks for the values - import collections - - phi_value_refs: dict[PhiValue, list[Value | tuple[Value, Node | None]]] = collections.defaultdict(list) - v: Value - known_nodes: set[Node] = set() - for bl in gr.blocks: - known_values: set[Value] = set(bl.block_inputs) - for i in bl.block_inputs: - for v in i.phi_values: - phi_value_refs[v].append(i) - for n in bl.nodes: - known_nodes.add(n) - assert n.source_infos, f"{n}({n.inputs}) does not have source infos" - n.block = bl - for i in n.inputs: - i_or_p = i - while not (i_or_p in known_values or i_or_p.is_const or i_or_p.is_global): - if i_or_p.parent is not None: - i_or_p = i_or_p.parent - else: - raise RuntimeError(f"unknown value {repr(i_or_p)} needed in {n}") - - for o in n.outputs: - known_values.add(o) - # inplace modified values are not re-assigned. should they, likely: yes - if o not in n.inputs: - for v in o.phi_values: - phi_value_refs[v].append((o, n)) - for o in bl.block_outputs: - is_attr = False - o_or_parent = o - while o_or_parent not in known_values and o_or_parent.parent is not None: - o_or_parent = o_or_parent.parent - is_attr = True - if is_attr: - for v in o.phi_values: - phi_value_refs[v].append((o, None)) - assert ( - o_or_parent in known_values or o_or_parent.is_const or o_or_parent.is_global - ), f"{o_or_parent} (from {o}) unknown {known_values=}" - - for bl in gr.blocks: - for i in bl.block_inputs: - assert isinstance(i, PhiValue) - assert len(i.jump_sources) == len(i.values) - assert len(i.values) > 0 - # assert i.block is bl - pvr = phi_value_refs.get(i, []) - assert len([v for v in i.values if not (v.is_function_arg or v.is_const or v.is_global)]) == len( - pvr - ), f"phi value {repr(i)} source count {len(i.values)} does not match sets {pvr}, {i.values}" - if i in phi_value_refs: # not for function args in first block - del phi_value_refs[i] - for v in i.values: - assert i in v.phi_values, f"phi value {repr(i)} not in phi_values of {repr(v)}" - for js in i.jump_sources: - assert js is None or js in known_nodes, f"phi value {repr(i)} jump source not found in graph {repr(js)}" - - assert not phi_value_refs, f"phi_values not found {phi_value_refs}" - - jump_targets: dict[Node | None, set[Block]] = {} - jump_targets[None] = {gr.blocks[0]} # function entry point - - for bl in gr.blocks: - for n in bl.nodes[:-1]: - assert not n.jump_targets - n = bl.nodes[-1] - if n.i.opname in {"RETURN_VALUE", "RAISE_VARARGS", "RERAISE"}: - assert not n.jump_targets - else: - assert 1 <= len(n.jump_targets) <= 2, f"{n} should have one or two ump targets, but has {n.jump_targets}" - jump_targets[n] = {jt for jt in n.jump_targets} - assert len(n.jump_targets) == len(jump_targets[n]) - - for bl in gr.blocks: - for js in bl.jump_sources: - js_jt = jump_targets[js] - js_jt.remove(bl) - - assert not any(jump_targets.values()), f"{jump_targets} should be all empty" - assert tuple(gr.blocks[0].jump_sources) == (None,), gr.blocks[0].jump_sources - - -def repr_source_location(gr: Graph, source_infos: list[SourceInformation]): - l = [] - for si in source_infos: - l.append(f"file: {si.orig_file_name}, line {si.orig_line_no}:") - ls = linecache.getlines(si.orig_file_name) - l.append(ls[max(si.orig_line_no - 1, 0)].rstrip()) - return "\n".join(l) - - -def check_graph(gr: Graph) -> None: - try: - _check_graph(gr) - cloned, _ = gr.clone() - _check_graph(cloned) - except BaseException: - print() - gr.summary(print_lines=True) - raise - - -def _generate_raises(msg): - @noinline - def _raise(): - raise AttributeError(msg) - - return _raise diff --git a/thunder/core/script/instrumentation.py b/thunder/core/script/instrumentation.py deleted file mode 100644 index d7c7ef182b..0000000000 --- a/thunder/core/script/instrumentation.py +++ /dev/null @@ -1,145 +0,0 @@ -import contextlib -import functools -import inspect -import logging -import threading -import typing - -from thunder.core.utils import debug_asserts_enabled - - -T = typing.TypeVar("T") -_STORAGE = threading.local() - - -def _lookup_state(name: str, factory: typing.Callable[[], T]) -> T: - if not hasattr(_STORAGE, name): - setattr(_STORAGE, name, factory()) - return getattr(_STORAGE, name) - - -get_stack = functools.partial(_lookup_state, "stack", list) -get_init_ctx = functools.partial(_lookup_state, "init_ctx", dict) -get_error_ctx = functools.partial(_lookup_state, "error_ctx", list) -get_logger = functools.partial(_lookup_state, "logger", lambda: logging.error) - - -class InstrumentingBase: - def __new__(cls, *_, **__) -> "InstrumentingBase": - self = super().__new__(cls) - if stack := get_stack(): - get_init_ctx()[id(self)] = (self, tuple(stack)) - - return self - - def _concise_repr(self) -> str: - return f"<{self.__class__.__name__} object at {hex(id(self))}>" - - -def emit_ctx(v, follow_delegates: bool): - for f, args, kwargs, delegate_to in reversed(get_init_ctx()[id(v)][1]): - signature = inspect.signature(f) - bound = signature.bind(*args, **kwargs) - bound.apply_defaults() - - def fmt_arg(k, v): - if v is signature.parameters[k].default: - return "..." - - if isinstance(v, InstrumentingBase): - v_repr = v._concise_repr() - elif callable(v) and hasattr(v, "__name__"): - v_repr = v.__name__ - else: - v_repr = repr(v) - - return v_repr - - if delegate_to is None or not follow_delegates: - arg_str = ", ".join(fmt_arg(k, v) for k, v in bound.arguments.items()) - yield f" {f.__name__:<30} {arg_str}" - - else: - x = bound.arguments[delegate_to] - yield f" {f.__name__:<30} {fmt_arg(delegate_to, x)}" - yield from emit_ctx(x, follow_delegates) - break - - -def maybe_flush_errors(): - if not get_stack() and (error_ctx := get_error_ctx()): - get_logger()("\n".join(reversed(error_ctx)) + "\n") - error_ctx.clear() - - -@contextlib.contextmanager -def intercept_errors(): - prior_logger = get_logger() - errors = [] - try: - _STORAGE.logger = lambda s: errors.append(s) - yield errors - finally: - _STORAGE.logger = prior_logger - - -def verbose_error(f): - @functools.wraps(f) - def wrapped(*args, **kwargs): - if not debug_asserts_enabled(): - return f(*args, **kwargs) - - try: - return f(*args, **kwargs) - except BaseException as e: - bound = inspect.signature(f).bind(*args, **kwargs) - bound.apply_defaults() - - f_name = f"| f.__name__ = {f.__name__} |" - lines = [f"\n{'-' * len(f_name)}\n{f_name}\n{'-' * len(f_name)}\n"] - for k, v in bound.arguments.items(): - lines.append(f"Argument(`{k}`):\n {v}\n") - if id(v) in get_init_ctx(): - lines.extend( - [ - "Context (raw):", - *reversed(tuple(emit_ctx(v, follow_delegates=False))), - "\nContext (augmented):", - *reversed(tuple(emit_ctx(v, follow_delegates=True))), - ] - ) - - get_error_ctx().append("\n".join(lines)) - maybe_flush_errors() - raise - - return wrapped - - -def record(delegate_to: str | None | typing.Callable = None): - # Hack to allow you to to decorate with `@record` instead of `@record()`. - if callable(delegate_to): - return record()(delegate_to) - - def wrapper(f): - f_verbose = verbose_error(f) - - @functools.wraps(f) - def wrapped(*args, **kwargs): - if not debug_asserts_enabled(): - return f(*args, **kwargs) - - stack = get_stack() - try: - stack.append((f, args, kwargs, delegate_to)) - return f_verbose(*args, **kwargs) - - finally: - _ = stack.pop() - maybe_flush_errors() - if not stack: - get_init_ctx().clear() - - return wrapped - - return wrapper diff --git a/thunder/core/script/mypy-strict.ini b/thunder/core/script/mypy-strict.ini deleted file mode 100644 index 1637b7df92..0000000000 --- a/thunder/core/script/mypy-strict.ini +++ /dev/null @@ -1,52 +0,0 @@ -# Forked from PyTorch's mypy-strict.ini file. -# It enforces very strict typing rules. - -[mypy] -python_version = 3.10 - -cache_dir = .mypy_cache/strict -allow_redefinition = True -strict_optional = True -show_error_codes = True -show_column_numbers = True -warn_no_return = True -disallow_any_unimported = True - -# Across versions of mypy, the flags toggled by --strict vary. To ensure -# we have reproducible type check, we instead manually specify the flags -warn_unused_configs = True -disallow_any_generics = True -disallow_subclassing_any = True -disallow_untyped_calls = True -disallow_untyped_defs = True -disallow_incomplete_defs = True -check_untyped_defs = True -disallow_untyped_decorators = True -no_implicit_optional = True -warn_redundant_casts = True -warn_return_any = True -implicit_reexport = False -strict_equality = True - -# do not re-enable this: -# https://github.com/pytorch/pytorch/pull/60006#issuecomment-866130657 -warn_unused_ignores = False - -files = - thunder/core/script/algorithms.py, - thunder/core/script/protograph.py, - thunder/core/script/protograph_passes.py, - thunder/core/script/parse, - thunder/core/script/values - -[mypy-thunder.core.utils] -follow_imports = silent - -[mypy-thunder] -follow_imports = skip - -[mypy-thunder.*] -follow_imports = skip - -[mypy-networkx] -ignore_missing_imports = True diff --git a/thunder/core/script/noinline.py b/thunder/core/script/noinline.py deleted file mode 100644 index 03dbee5e7b..0000000000 --- a/thunder/core/script/noinline.py +++ /dev/null @@ -1,38 +0,0 @@ -from contextvars import ContextVar -from collections.abc import Callable - - -NOINLINE_METHODS: ContextVar[set[Callable]] = ContextVar("NOINLINE_METHODS", default=set()) - - -def noinline(f: Callable) -> Callable: - """ - Function/Decorator to prevent preprocessing from inlining the function. - - Example: - >>> @noinline - >>> def foo(x): - >>> return x + 1 - >>> def bar(x): - >>> return foo(x) + 1 - >>> thunder.compile(bar) - """ - - NOINLINE_METHODS.get().add(f) - return f - - -@noinline -def invoke_noinline(f: Callable) -> Callable: - """ - Function to prevent preprocessing from inlining a single invocation of a function. - - Example: - >>> def foo(x): - >>> return x + 1 - >>> def bar(x): - >>> return invoke_noinline(foo)(x) + 1 - >>> thunder.compile(bar) - """ - - return f diff --git a/thunder/core/script/overview.ipynb b/thunder/core/script/overview.ipynb deleted file mode 100644 index 152d98a6e2..0000000000 --- a/thunder/core/script/overview.ipynb +++ /dev/null @@ -1,331 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "5c236b7e-9191-4ead-8d76-cd789c09f810", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import sys\n", - "\n", - "thunder_path = os.path.abspath(os.path.join(os.path.abspath(\"\"), \"..\", \"..\", \"..\"))\n", - "if thunder_path not in sys.path:\n", - " sys.path.append(thunder_path)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "332a649c-cf11-4803-ae91-fff31d31d359", - "metadata": {}, - "outputs": [], - "source": [ - "import thunder" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "02373250-c7b0-446b-ad21-6e847d58feec", - "metadata": {}, - "outputs": [], - "source": [ - "def masked_apply(x, mask, layer_0, layer_1):\n", - " x = layer_0(x, mask)\n", - " x = layer_1(x, mask if mask is not None else 1)\n", - " return x, mask is None" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "9c345326-89ae-4efa-9657-545ae3eeb215", - "metadata": {}, - "outputs": [], - "source": [ - "from thunder.core.script.frontend import _construct_protograph, acquire_method\n", - "proto_graph = _construct_protograph(masked_apply)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "2205c64a-caa8-4913-81ce-a43b10ffb1d1", - "metadata": {}, - "outputs": [], - "source": [ - "from thunder.core.script import parse, values\n", - "\n", - "provenance = []\n", - "last = proto_graph.provenance\n", - "while isinstance(last, tuple):\n", - " provenance.append(last)\n", - " transform, prior_proto_graph = last\n", - " last = prior_proto_graph.provenance\n", - "assert isinstance(last, values.ParsedSymbolic)\n", - "provenance.extend((last, last.provenance, last.provenance.provenance))\n", - "provenance.reverse()\n", - "\n", - "disassembled, parsed, parsed_symbolic, *protograph_transforms = provenance\n", - "assert isinstance(disassembled, parse.Disassembled)\n", - "assert isinstance(parsed, parse.ParsedFunctional)\n", - "assert isinstance(parsed_symbolic, values.ParsedSymbolic)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "b0178f36-4774-483b-8d37-e2271c42f2de", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " \n", - "\n", - "\n", - " 2 0 LOAD_FAST 2 (layer_0) || 0) LOAD_FAST\n", - " 2 LOAD_FAST 0 (x) || LOAD_FAST\n", - " 4 LOAD_FAST 1 (mask) || LOAD_FAST\n", - " 6 CALL_FUNCTION 2 || CALL_FUNCTION\n", - " 8 STORE_FAST 0 (x) || STORE_FAST\n", - " || LOAD_FAST\n", - " 3 10 LOAD_FAST 3 (layer_1) || LOAD_FAST\n", - " 12 LOAD_FAST 0 (x) || LOAD_FAST\n", - " 14 LOAD_FAST 1 (mask) || LOAD_CONST\n", - " 16 LOAD_CONST 0 (None) || IS_OP\n", - " 18 IS_OP 1 || POP_JUMP_IF_FALSE\n", - " 20 POP_JUMP_IF_FALSE 13 (to 26) || \n", - " 22 LOAD_FAST 1 (mask) || 1) LOAD_FAST\n", - " 24 JUMP_FORWARD 1 (to 28) || JUMP_FORWARD\n", - " >> 26 LOAD_CONST 1 (1) || \n", - " >> 28 CALL_FUNCTION 2 || 2) LOAD_CONST\n", - " 30 STORE_FAST 0 (x) || JUMP_ABSOLUTE\n", - " || \n", - " 4 32 LOAD_FAST 0 (x) || 3) CALL_FUNCTION\n", - " 34 LOAD_FAST 1 (mask) || STORE_FAST\n", - " 36 LOAD_CONST 0 (None) || LOAD_FAST\n", - " 38 IS_OP 0 || LOAD_FAST\n", - " 40 BUILD_TUPLE 2 || LOAD_CONST\n", - " 42 RETURN_VALUE || IS_OP\n", - " || BUILD_TUPLE\n", - " || RETURN_VALUE\n", - " || \n" - ] - } - ], - "source": [ - "import dis\n", - "import io\n", - "import itertools\n", - "\n", - "print(disassembled.code, \"\\n\\n\")\n", - "dis.dis(disassembled.code, file=(buffer := io.StringIO()))\n", - "buffer.seek(0)\n", - "dis_lines = buffer.read().splitlines(False)\n", - "\n", - "block_lines = []\n", - "for idx, block in enumerate(disassembled.raw.values()):\n", - " block_lines.extend(f\"{'' if idy else f'{idx})':<4}{instruction.opname}\" for idy, instruction in enumerate(block))\n", - " block_lines.append(\"\")\n", - "\n", - "pad = max(len(l) for l in dis_lines)\n", - "for dis_line, block_line in itertools.zip_longest(dis_lines, block_lines, fillvalue=\"\"):\n", - " print(f\"{dis_line:<{pad + 10}} ||{' ' * 10}{block_line}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "dec0f942-0915-4cf5-a79a-9e37fd2f4e6f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Block 0: [] => [layer_1, v0]\n", - " LOAD[layer_0, x, mask]\n", - " CALL_FUNCTION . . . . . . . . . . . . . . (layer_0, x, mask) -> v0\n", - " STORE[x]\n", - " LOAD[layer_1, x, mask, None: CONST]\n", - " IS_OP . . . . . . . . . . . . . . . . . . (mask, None) -> v1\n", - " POP_JUMP_IF_FALSE . . . . . . . . . . . . (v1) -> \n", - " -> 1, 2(Jump)\n", - "\n", - "Block 1: [⓵ , ⓶ ] => [⓵ , ⓶ , mask]\n", - " LOAD[mask]\n", - " JUMP_FORWARD\n", - " -> 3(Jump)\n", - "\n", - "Block 2: [⓵ , ⓶ ] => [⓵ , ⓶ , 1]\n", - " LOAD[1: CONST]\n", - " JUMP_ABSOLUTE*\n", - " -> 3(Jump)\n", - "\n", - "Block 3: [⓵ , ⓶ , ⓷ ] => []\n", - " CALL_FUNCTION . . . . . . . . . . . . . . (⓵ , ⓶ , ⓷ ) -> v0\n", - " STORE[x]\n", - " LOAD[x, mask, None: CONST]\n", - " IS_OP . . . . . . . . . . . . . . . . . . (mask, None) -> v1\n", - " BUILD_TUPLE . . . . . . . . . . . . . . . (v0, v1) -> v2\n", - " RETURN_VALUE . . . . . . . . . . . . . . (v2) -> \n", - "\n" - ] - } - ], - "source": [ - "print(parsed.summary)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "3a1cdbfe-996f-48c2-a5d4-b05389d2316d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CALL_FUNCTION (layer_0(LOCAL), x(LOCAL), mask(LOCAL)) -> (IntermediateValue(at 0x7f064f0c26b0),)\n", - "IS_OP (mask(LOCAL), None(CONST)) -> (IntermediateValue(at 0x7f064f0c3070),)\n", - "POP_JUMP_IF_FALSE (OutputRef(IS_OP, idx=0)) -> ()\n", - "\n", - "JUMP_FORWARD () -> ()\n", - "\n", - "JUMP_ABSOLUTE () -> ()\n", - "\n", - "CALL_FUNCTION (0(STACK), 1(STACK), 2(STACK)) -> (IntermediateValue(at 0x7f064f0f0040),)\n", - "IS_OP (mask(LOCAL), None(CONST)) -> (IntermediateValue(at 0x7f064f0f0250),)\n", - "BUILD_TUPLE (OutputRef(CALL_FUNCTION, idx=0), OutputRef(IS_OP, idx=0)) -> (IntermediateValue(at 0x7f064f0f0460),)\n", - "RETURN_VALUE (OutputRef(BUILD_TUPLE, idx=0)) -> ()\n", - "\n" - ] - } - ], - "source": [ - "def pretty_repr(x) -> str:\n", - " if isinstance(x, parse.VariableKey):\n", - " return f\"{x.identifier}({x.scope.name})\"\n", - " if isinstance(x, values.OutputRef):\n", - " return f\"OutputRef({x.instruction.opname}, idx={x.idx})\"\n", - " return repr(x)\n", - "\n", - "for block, begin, end in parsed_symbolic.blocks:\n", - " # At this point `begin` isn't very interesting as it's all just placeholders.\n", - " for instruction, symbolic in block.items():\n", - " \n", - " inputs = \", \".join(pretty_repr(i) for i in symbolic.inputs.ordered)\n", - " print(f\"{instruction.opname:<25} ({inputs}) -> {symbolic.outputs}\")\n", - " print()" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "8daca4b6-31b3-4d8e-8be2-9744b44d8f4f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Unlink\n", - "MarkTuples\n", - "AddTransitive\n", - "Connect\n", - "\n", - "================================================================================\n", - " ProtoBlock: 0x7f064f0f1cf0\n", - " CALL_FUNCTION\n", - " IS_OP\n", - " POP_JUMP_IF_FALSE\n", - "\n", - "ProtoBlock: 0x7f064ef1e560\n", - " JUMP_FORWARD\n", - "\n", - "ProtoBlock: 0x7f064ef81b10\n", - " JUMP_ABSOLUTE\n", - "\n", - "ProtoBlock: 0x7f064ef9cfd0\n", - " CALL_FUNCTION\n", - " IS_OP\n", - " BUILD_TUPLE\n", - " RETURN_VALUE\n" - ] - } - ], - "source": [ - "for transform, _ in protograph_transforms:\n", - " print(transform.__name__)\n", - "\n", - "print(f\"\\n{'=' * 80}\\n\", proto_graph)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "5c56fbcf-13db-4749-940b-ecbb0dfd0af5", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Graph of\n", - " Block (reached from [None])\n", - " CALL_FUNCTION([, , ])\n", - " IS_OP 1 (1)\n", - " POP_JUMP_IF_FALSE 13 (26)\n", - " Block (reached from [])\n", - " JUMP_FORWARD 1 (28)\n", - " Block (reached from [])\n", - " JUMP_ABSOLUTE 14 (None)\n", - " Block (reached from [, ])\n", - " CALL_FUNCTION([, , ])\n", - " IS_OP 0 (0)\n", - " BUILD_TUPLE 2 (2)\n", - " RETURN_VALUE None (None)\n" - ] - } - ], - "source": [ - "g = acquire_method(masked_apply)\n", - "print(g)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "22ef39c9-04aa-40ff-acc2-1387a3f372f4", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/thunder/core/script/parse/__init__.py b/thunder/core/script/parse/__init__.py deleted file mode 100644 index c29adb91a6..0000000000 --- a/thunder/core/script/parse/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from thunder.core.script.parse.disassemble import * -from thunder.core.script.parse.functionalize import * -from thunder.core.script.parse.instructions import * -from thunder.core.script.parse.stack_effect import * - -# This will be populated as parse-time narrowing is introduced. -FORBIDDEN_INSTRUCTIONS = InstructionSet() diff --git a/thunder/core/script/parse/disassemble.py b/thunder/core/script/parse/disassemble.py deleted file mode 100644 index 3ae5ae1bbe..0000000000 --- a/thunder/core/script/parse/disassemble.py +++ /dev/null @@ -1,166 +0,0 @@ -"""Convert a `CodeType` object into a series of simple blocks.""" -from __future__ import annotations - -import dataclasses -import dis -import itertools -from types import CodeType -from typing import Any, NewType, TypeVar -from collections.abc import Iterable, Mapping - -from thunder.core.script.parse import stack_effect -from thunder.core.script.parse.instructions import * # There are a lot of constants, and it defines `__all__` - - -__all__ = ("Disassembled", "ParseDetailInstruction", "EpilogueFixup", "Jump") - -BlockIdx = NewType("BlockIdx", int) -Jump = NewType("Jump", bool) - - -@dataclasses.dataclass -class Disassembled: - code: CodeType - - _StartIndex = NewType("_StartIndex", int) - _RawBlocks = dict[_StartIndex, tuple[ThunderInstruction, ...]] - raw: RawBlocks - - _Blocks = tuple[tuple["ThunderInstruction", ...], ...] - blocks: _Blocks - - _Edges = tuple[tuple[BlockIdx, BlockIdx, Jump], ...] - edges: _Edges - - @classmethod - def make(cls, co: CodeType) -> Disassembled: - raw_blocks, last_line_no = partition(co) - blocks, edges = connect_blocks(consolidate_returns(raw_blocks)) - for instruction in itertools.chain(*blocks): - instruction.line_no = getattr(instruction, "line_no", last_line_no) - return cls(code=co, raw=raw_blocks, blocks=blocks, edges=edges) - - -class ParseDetailInstruction(ThunderInstruction): - """Allow us to distinguish instructions that are added during parsing.""" - - pass - - -class EpilogueFixup(ParseDetailInstruction): - pass - - -def compute_jump(instruction: ThunderInstruction, position: int) -> int | None: - if instruction in ABSOLUTE_JUMP_INSTRUCTIONS: - return instruction.oparg - - elif instruction in UNCONDITIONAL_BACKWARD: - return position + 1 - instruction.oparg - - elif "BACKWARD" in instruction.opname: - # TODO: POP_JUMP_BACKWARD_IF_... variants - raise NotImplementedError(instruction.opname) - - elif instruction in RELATIVE_JUMP_INSTRUCTIONS: - return position + 1 + instruction.oparg - - return None - - -IntT = TypeVar("IntT", bound=int, covariant=True) -StartIndex = NewType("StartIndex", int) -RawBlocks = dict[StartIndex, tuple[ThunderInstruction, ...]] - - -def get_free_key(x: Mapping[IntT, Any]) -> int: - key = -len(x) - while key in x: - key -= 1 - return key - - -def partition(co: CodeType) -> tuple[RawBlocks, int]: - bytecode = tuple(ThunderInstruction(*i) for i in dis.get_instructions(co, first_line=0)) - - # Determine the boundaries for the simple blocks. - split_after = JUMP_INSTRUCTIONS | RETURN_INSTRUCTIONS - follows_jump = itertools.chain([0], (int(i in split_after) for i in bytecode)) - new_block = (int(i or j.is_jump_target) for i, j in zip(follows_jump, bytecode)) - - # Split the bytecode (and instruction number) into groups - group_indices = tuple(itertools.accumulate(new_block)) - groups = itertools.groupby(enumerate(bytecode), key=lambda args: group_indices[args[0]]) - - # Drop the group index, copy from the groupby iter, and unzip `enumerate`. - groups = (zip(*tuple(i)) for _, i in groups) - blocks: dict[StartIndex, list[ThunderInstruction]] = { - StartIndex(start): list(block) for (start, *_), block in groups - } - - # If the last instruction is not a jump or return (which means we split - # because the next instruction was a jump target) then we need to tell - # the current block how to advance. - for start, block in blocks.items(): - if block[-1] not in split_after: - next_start = StartIndex(start + len(block)) - assert bytecode[next_start].is_jump_target - block.append(ParseDetailInstruction.make_jump_absolute(next_start)) - - line_no = 1 - for instruction in itertools.chain(*[block for block in blocks.values()]): - instruction.line_no = line_no = instruction.starts_line or line_no - - return {k: tuple(v) for k, v in blocks.items()}, line_no - - -def consolidate_returns(blocks: RawBlocks) -> RawBlocks: - def is_return(block: tuple[ThunderInstruction, ...]) -> bool: - assert block and not any(i.opname == RETURN_VALUE for i in block[:-1]) - return block[-1].opname == RETURN_VALUE - - blocks = blocks.copy() - return_blocks = {k: v for k, v in blocks.items() if is_return(v)} - if len(return_blocks) > 1: - new_return_start = StartIndex(get_free_key(blocks)) - for start, (*body, prior_return) in return_blocks.items(): - assert is_return((prior_return,)), prior_return - blocks[start] = (*body, ParseDetailInstruction.make_jump_absolute(new_return_start)) - return_blocks = {new_return_start: (ParseDetailInstruction.make_return(is_jump_target=True),)} - - # Move return block to the end. This isn't always valid (since a block might - # expect to fall through and reach it), but that will be resolved by the - # sort in `ProtoGraph`'s ctor. - blocks = {k: v for k, v in blocks.items() if k not in return_blocks} - blocks.update(return_blocks) - return blocks - - -def connect_blocks(blocks: RawBlocks) -> tuple[Disassembled._Blocks, Disassembled._Edges]: - def iter_raw_edges(blocks: RawBlocks) -> Iterable[tuple[StartIndex, StartIndex, Jump, int, int]]: - for start, block in tuple(blocks.items()): - raw_block_len = sum(not isinstance(i, ParseDetailInstruction) for i in block) - *_, last_i = block - if last_i in JUMP_INSTRUCTIONS: - end = start + raw_block_len - 1 - _, (push_nojump, push_jump) = stack_effect.stack_effect_detail(last_i) - if last_i not in UNCONDITIONAL_JUMP_INSTRUCTIONS: - yield start, StartIndex(end + 1), Jump(False), max(push_jump - push_nojump, 0), last_i.line_no - - if (jump_offset := compute_jump(last_i, end)) is not None: - yield start, StartIndex(jump_offset), Jump(True), max(push_nojump - push_jump, 0), last_i.line_no - - blocks = blocks.copy() - edges: list[tuple[StartIndex, StartIndex, Jump]] = [] - for source, destination, jump, pop_suffix, line_no in iter_raw_edges(blocks): - if pop_suffix: - blocks[epilogue := StartIndex(get_free_key(blocks))] = ( - *(EpilogueFixup.make(POP_TOP, None, line_no=line_no) for _ in range(pop_suffix)), - EpilogueFixup.make_jump_absolute(destination, line_no=line_no), - ) - edges.extend(((source, epilogue, jump), (epilogue, destination, jump))) - else: - edges.append((source, destination, jump)) - - to_idx = {k: BlockIdx(idx) for idx, k in enumerate(blocks.keys())} - return tuple(blocks.values()), tuple((to_idx[source], to_idx[sink], jump) for source, sink, jump in edges) diff --git a/thunder/core/script/parse/functionalize.py b/thunder/core/script/parse/functionalize.py deleted file mode 100644 index 9664271918..0000000000 --- a/thunder/core/script/parse/functionalize.py +++ /dev/null @@ -1,287 +0,0 @@ -"""Replay the CPython stack machine to determine data flow within a simple block.""" -from __future__ import annotations - -import dataclasses -import enum -import inspect -import itertools -import marshal -import textwrap -from types import CodeType -from typing import Any, NamedTuple, NewType - -import networkx as nx - -from thunder.core.script import algorithms -from thunder.core.script.parse import disassemble, instructions, stack_effect -from thunder.core.utils import safe_zip, FrozenDict, InferringDict - -__all__ = ("VariableScope", "VariableKey", "ParsedFunctional", "FunctionalizedBlock", "PlaceholderValue") - - -class VariableScope(enum.Enum): - CONST = enum.auto() - LOCAL = enum.auto() - NONLOCAL = enum.auto() - GLOBAL = enum.auto() - STACK = enum.auto() - - -class VariableKey(NamedTuple): - """Denotes the location of a variable. - For example, `x = 5` assigns the variable stored in `VariableKey(5, VariableScope.CONST)` - to the location `VariableKey("x", VariableScope.LOCAL)`. (Provided `x` is a local variable.) - The type of `identifier` varies based on `scope`: - `marshal`able VariableScope.CONST - str VariableScope.LOCAL / NONLOCAL / GLOBAL - int VariableScope.STACK - Any VariableScope.BOUNDARY - """ - - identifier: Any - scope: VariableScope - - def __repr__(self) -> str: - return f"VariableKey({self.identifier}, scope={self.scope.name})" - - def __eq__(self, other: object) -> bool: - return ( - isinstance(other, VariableKey) - and self.scope == other.scope - and type(self.identifier) is (_ := type(other.identifier)) # Conflict between `ruff` and `yesqa` - and self.identifier == other.identifier - ) - - def __lt__(self, other: tuple[Any, ...]) -> bool: - assert isinstance(other, VariableKey), (self, other) - try: - return (self.scope.value, self.identifier) < (other.scope.value, other.identifier) - except TypeError: - assert self.scope == other.scope, (self, other) - if self.scope == VariableScope.CONST: - # We prefer to use native ordering. However for unorderable types (e.g. CodeType) - # `marshal` at least provides a consistent ordering. - return marshal.dumps(self.identifier) < marshal.dumps(other.identifier) - raise - - @property - def is_const(self) -> bool: - return self.scope == VariableScope.CONST - - -def _compute_stack_offsets(disassembled: disassemble.Disassembled) -> tuple[int, ...]: - # If we convert the stack indices to a common basis then we can ignore stack effect and - # treat VariableScope.STACK variables just like any other local. - G = algorithms.TypedDiGraph[disassemble.BlockIdx]((i, j) for i, j, _ in disassembled.edges) - G.add_nodes_from(range(len(disassembled.blocks))) - offsets: dict[disassemble.BlockIdx, int] = {i: 0 for i in G.nodes if not G.in_degree(i)} # type: ignore[misc] - assert len(offsets) == 1, G - - for source, sink in nx.edge_dfs(G): - net_stack_effect = 0 - for instruction in disassembled.blocks[source]: - pop, push_by_branch = stack_effect.stack_effect_detail(instruction) - net_stack_effect += max(push_by_branch) - pop - expected = offsets[source] + net_stack_effect - actual = offsets.setdefault(sink, expected) - assert actual == expected, (actual, expected) - - assert all(v >= 0 for v in offsets.values()), offsets - return tuple(offsets[disassemble.BlockIdx(i)] for i in range(len(disassembled.blocks))) - - -LOAD_OPNAMES = FrozenDict[str, VariableScope]( - LOAD_CONST=VariableScope.CONST, - LOAD_FAST=VariableScope.LOCAL, - LOAD_DEREF=VariableScope.NONLOCAL, - LOAD_CLOSURE=VariableScope.NONLOCAL, - LOAD_GLOBAL=VariableScope.GLOBAL, -) - -STORE_OPNAMES = FrozenDict[str, VariableScope]( - STORE_FAST=VariableScope.LOCAL, - STORE_DEREF=VariableScope.NONLOCAL, - STORE_GLOBAL=VariableScope.GLOBAL, -) - -DEL_OPNAMES = FrozenDict[str, VariableScope]( - DELETE_FAST=VariableScope.LOCAL, - DELETE_DEREF=VariableScope.NONLOCAL, - DELETE_GLOBAL=VariableScope.GLOBAL, -) - -PlaceholderValue = NewType("PlaceholderValue", str) -Inputs = NewType("Inputs", tuple[PlaceholderValue, ...]) -Outputs = NewType("Outputs", tuple[PlaceholderValue, ...]) -BeginState = FrozenDict["VariableKey", PlaceholderValue] -EndState = FrozenDict["VariableKey", PlaceholderValue | None] -FunctionalNode = tuple[instructions.ThunderInstruction, Inputs, Outputs] -FunctionalizedBlock = NewType("FunctionalizedBlock", tuple[tuple[FunctionalNode, ...], BeginState, EndState]) - - -@dataclasses.dataclass(frozen=True) -class ParsedFunctional: - blocks: tuple[FunctionalizedBlock, ...] - provenance: disassemble.Disassembled - - @staticmethod - def make(co: CodeType) -> ParsedFunctional: - disassembled = disassemble.Disassembled.make(co) - return ParsedFunctional(_functionalize_blocks(disassembled), disassembled) - - @property - def summary(self) -> str: - return _summarize(self) - - -def _functionalize_blocks(disassembled: disassemble.Disassembled) -> tuple[FunctionalizedBlock, ...]: - code = disassembled.code - errors: list[str] = [] - if code.co_cellvars: - errors.append( - "Nonlocal variables are not supported but\n" - f" {code.co_name}() defined in {code.co_filename}:{code.co_firstlineno}\n" - f" defines nonlocal variable{'s' if len(code.co_cellvars) > 1 else ''}: {', '.join(code.co_cellvars)}" - ) - - def report_unsupported(msg: str, instruction: instructions.ThunderInstruction) -> None: - source_lines, _ = inspect.getsourcelines(code) - errors.append( - f"{msg}{instruction} found\n" - f" {code.co_name}() defined in {code.co_filename}:{code.co_firstlineno}\n" - f" line {instruction.line_no + code.co_firstlineno}: {source_lines[instruction.line_no].rstrip()}" - ) - - name_arrays = FrozenDict[VariableScope, tuple[str, ...]]( - { - VariableScope.CONST: code.co_consts, - VariableScope.LOCAL: code.co_varnames, - VariableScope.NONLOCAL: (*code.co_cellvars, *code.co_freevars), - VariableScope.GLOBAL: code.co_names, - } - ) - - def to_key(instruction: instructions.ThunderInstruction, scope: VariableScope) -> VariableKey: - assert scope != VariableScope.STACK, "Indexing into the stack is not permitted." - assert scope in name_arrays, f"Unknown variable scope: {scope}" - if scope == VariableScope.NONLOCAL: - report_unsupported("nonlocal variables are not supported but instruction = ", instruction) - return VariableKey(name_arrays[scope][instruction.oparg], scope) - - def convert(block: tuple[instructions.ThunderInstruction, ...], stack_offset: int) -> FunctionalizedBlock: - stack: list[PlaceholderValue] = [PlaceholderValue(f"Initial_stack_{i}") for i in range(stack_offset)] - begin_variables = {VariableKey(idx, VariableScope.STACK): v for idx, v in enumerate(stack)} - end_variables = InferringDict[VariableKey, PlaceholderValue | None]( - lambda key: begin_variables.setdefault(key, PlaceholderValue(f"Initial: ({key.identifier} {key.scope})")) - ) - - assert block - functionalized: list[tuple[instructions.ThunderInstruction, Inputs, Outputs]] = [] - for idx, instruction in enumerate(block): - # These are already reflected in the next opcode's argument - if instruction.opname == instructions.EXTENDED_ARG: - continue - - elif instruction in instructions.UNSAFE_OPCODES: - # These are unsafe to run, but we should still be able to parse them. - report_unsupported("Unsupported instruction = ", instruction) - - pop, push_by_branch = stack_effect.stack_effect_detail(instruction) - push = max(push_by_branch) - - def assert_expected_stack_effects(pop_i: int, push_i: int) -> None: - assert (pop, push) == (pop_i, push_i), f"{instruction=} {pop=} {push=}" - - # Peek at the stack to track variable mutations. - if (store_scope := STORE_OPNAMES.get(instruction.opname)) is not None: - assert_expected_stack_effects(1, 0) - end_variables[to_key(instruction, store_scope)] = stack.pop() - - elif (del_scope := DEL_OPNAMES.get(instruction.opname)) is not None: - assert_expected_stack_effects(1, 0) - end_variables[to_key(instruction, del_scope)] = None - - elif (load_scope := LOAD_OPNAMES.get(instruction.opname)) is not None: - assert_expected_stack_effects(0, 1) - loaded = end_variables[load_key := to_key(instruction, load_scope)] - assert loaded is not None, f"Access to deleted variable: {load_key}, {instruction}" - stack.append(loaded) - - else: - # We have already functionalized variable accesses, so we can prune loads and stores. - inputs = tuple(stack.pop() for _ in range(pop)) - outputs = Outputs(tuple(PlaceholderValue(f"{idx}_{instruction.opname}__{idy}") for idy in range(push))) - stack.extend(outputs) - functionalized.append((instruction, Inputs(tuple(reversed(inputs))), outputs)) - - end_stack = {VariableKey(idx, VariableScope.STACK): v for idx, v in enumerate(stack)} - end_state: EndState = FrozenDict({**end_variables, **end_stack}) - return FunctionalizedBlock((tuple(functionalized), FrozenDict(begin_variables), end_state)) - - stack_offsets = _compute_stack_offsets(disassembled) - functionalized = tuple(convert(block, offset) for block, offset in safe_zip(disassembled.blocks, stack_offsets)) - if errors: - raise RuntimeError("Preprocessing issues detected:\n" + textwrap.indent("\n\n".join(errors), " " * 4)) - - return functionalized - - -# ============================================================================= -# == Summary for debugging and testing ======================================== -# ============================================================================= -def _summarize(parsed: ParsedFunctional) -> str: - # Clear identifiers for input stack values. - to_symbol = FrozenDict[int, str](enumerate("⓵ ⓶ ⓷ ⓸ ⓹ ⓺ ⓻ ⓼ ⓽ ⓾ Ⓐ Ⓑ Ⓒ Ⓓ Ⓔ Ⓕ".split())) - - # Group output edges. - grouped_edges: dict[int, str] = { - source: ", ".join(f"{sink}{'(Jump)' if jump else ''}" for _, sink, jump in sinks) - for source, sinks in itertools.groupby(parsed.provenance.edges, lambda e: e[0]) - } - - # Best effort to apply descriptive names. - inputs_outputs = {} - block_headers: list[str] = [] - for idx, (functionalized_block, begin, end) in enumerate(parsed.blocks): - begin_stack = tuple(v for k, v in begin.items() if k.scope == VariableScope.STACK) - stack_names: dict[str, str] = {v: to_symbol.get(idx, f"S{idx}") + "\u2009" for idx, v in enumerate(begin_stack)} - names: dict[str, str] = {**{v: f"{k.identifier}" for k, v in begin.items()}, **stack_names} - names.update({v: f"v{idx}" for idx, v in enumerate(itertools.chain(*[o for _, _, o in functionalized_block]))}) - for instruction, inputs, outputs in functionalized_block: - inputs_outputs[instruction] = (tuple(names[i] for i in inputs), tuple(names[o] for o in outputs)) - - end_stack = {k: v for k, v in end.items() if k.scope == VariableScope.STACK} - assert tuple(k.identifier for k in end_stack) == tuple(range(len(end_stack))), end_stack - end_stack_str = ", ".join(names[i or ""] for i in end_stack.values()) - block_headers.append(f"Block {idx}: [{', '.join(stack_names.values())}] => [{end_stack_str}]") - - # Group loads and stores. - prefix = {opname: opname.split("_")[0] for opname in itertools.chain(STORE_OPNAMES, LOAD_OPNAMES, DEL_OPNAMES)} - condensed: list[list[tuple[str, instructions.ThunderInstruction | None]]] = [] - for raw_block in parsed.provenance.blocks: - condensed.append([]) - for prefix_or_i, group in itertools.groupby(raw_block, lambda i: prefix.get(i.opname, i)): - if isinstance(prefix_or_i, str): - name = ", ".join(f"{i.argval}: {i.opname[len(prefix_or_i) + 1:]}" for i in group) - condensed[-1].append((f"{prefix_or_i}[{name.replace(': FAST', '')}]", None)) - else: - opname = f"{prefix_or_i.opname}{'' if type(prefix_or_i) is instructions.ThunderInstruction else '*'}" - condensed[-1].append((opname, prefix_or_i)) - - # Write lines. - lines: list[str] = [] - width = max(len(name) for name, _ in itertools.chain(*condensed)) - width = max(width, max(len(i) for i in block_headers)) + 5 - for idx, (condensed_block, (_, _, end)) in enumerate(safe_zip(condensed, parsed.blocks)): - lines.append(block_headers[idx]) - for name, maybe_i in condensed_block: - inputs, outputs = inputs_outputs.get(maybe_i, ((), ())) # type: ignore[assignment, arg-type] - if inputs or outputs: - name = f"{name} ".ljust(width, ".").replace("..", ". ") - name = f"{name} ({', '.join(inputs)}) -> {', '.join(outputs)}" - lines.append(f" {name}") - if idx in grouped_edges: - lines.append(f" -> {grouped_edges[idx]}") - lines.append("") - - return "\n".join(lines) diff --git a/thunder/core/script/parse/instructions.py b/thunder/core/script/parse/instructions.py deleted file mode 100644 index 34fb7fed07..0000000000 --- a/thunder/core/script/parse/instructions.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Extension of the builtin `dis` module.""" -from __future__ import annotations - -import dis -from typing import Any - -from typing_extensions import Self - -from thunder.core.utils import _OrderedSet - -__all__ = ( - "ThunderInstruction", - "InstructionSet", - "JUMP_ABSOLUTE", - "RETURN_VALUE", - "POP_TOP", - "EXTENDED_ARG", - "UNCONDITIONAL_BACKWARD", - "UNCONDITIONAL_JUMP_INSTRUCTIONS", - "ABSOLUTE_JUMP_INSTRUCTIONS", - "RELATIVE_JUMP_INSTRUCTIONS", - "JUMP_INSTRUCTIONS", - "RAISE_RETURN_INSTRUCTIONS", - "RETURN_INSTRUCTIONS", - "UNSAFE_OPCODES", -) - - -class ThunderInstruction(dis.Instruction): - """Thin wrapper on top of dis.Instruction to implement thunder specific logic.""" - - line_no: int - - def __hash__(self) -> int: - # We sometimes want to use an instruction as a key so we can map back to nodes. - # `dis.Instruction` is a named tuple and therefore implements recursive constituent - # hashing which can lead to unwanted collisions. We instead override this behavior - # to instead use identity hashing. - return id(self) - - def __eq__(self, other: object) -> bool: - return self is other - - @property - def oparg(self) -> int: - assert self.arg is not None, self - return self.arg - - def modify_copy(self, **kwargs: Any) -> ThunderInstruction: - assert type(self) is ThunderInstruction, self - if "opname" in kwargs: - kwargs.setdefault("opcode", dis.opmap.get(kwargs["opname"], -1)) - result = ThunderInstruction(**{**self._asdict(), **kwargs}) - result.line_no = self.line_no - return result - - @classmethod - def make(cls, opname: str, arg: int | None, line_no: int, **kwargs: Any) -> Self: - ctor_kwargs = dict( - opname=opname, - opcode=dis.opmap.get(opname, -1), - arg=arg, - argval=None, - argrepr="None", - offset=-999, - starts_line=None, - is_jump_target=False, - ) - ctor_kwargs.update(kwargs) - result = cls(**ctor_kwargs) # type: ignore - result.line_no = line_no - return result - - @classmethod - def make_jump_absolute(cls, arg: int, line_no: int = -1) -> ThunderInstruction: - return cls.make(JUMP_ABSOLUTE, arg, argrepr=f"{arg}", line_no=line_no) - - @classmethod - def make_return(cls, is_jump_target: bool, line_no: int = -1) -> ThunderInstruction: - return cls.make(RETURN_VALUE, arg=None, argrepr="", is_jump_target=is_jump_target, line_no=line_no) - - -class InstructionSet(_OrderedSet[str, int | ThunderInstruction]): - """Convenience class for checking opcode properties.""" - - def canonicalize(self, i: str | int | ThunderInstruction) -> str: - if isinstance(i, str): - return i - - elif isinstance(i, int): - return dis.opname[i] - - else: - assert isinstance(i, ThunderInstruction) - return i.opname - - -# Special opcodes -JUMP_ABSOLUTE = "JUMP_ABSOLUTE" -RETURN_VALUE = "RETURN_VALUE" -POP_TOP = "POP_TOP" -EXTENDED_ARG = "EXTENDED_ARG" - - -UNCONDITIONAL_BACKWARD = InstructionSet(("JUMP_BACKWARD", "JUMP_BACKWARD_NO_INTERRUPT")) -UNCONDITIONAL_JUMP_INSTRUCTIONS = InstructionSet((JUMP_ABSOLUTE, "JUMP_FORWARD", *UNCONDITIONAL_BACKWARD)) - -ABSOLUTE_JUMP_INSTRUCTIONS = InstructionSet(dis.hasjabs) -RELATIVE_JUMP_INSTRUCTIONS = InstructionSet(dis.hasjrel) -JUMP_INSTRUCTIONS = InstructionSet((*dis.hasjabs, *dis.hasjrel, *UNCONDITIONAL_JUMP_INSTRUCTIONS)) - -RAISE_RETURN_INSTRUCTIONS = InstructionSet(("RAISE_VARARGS", "RERAISE")) -RETURN_INSTRUCTIONS = InstructionSet((RETURN_VALUE, *RAISE_RETURN_INSTRUCTIONS)) - - -# https://github.com/Lightning-AI/lightning-thunder/issues/1075 -UNSAFE_OPCODES = InstructionSet(("SETUP_WITH", "SETUP_FINALLY")) diff --git a/thunder/core/script/parse/stack_effect.py b/thunder/core/script/parse/stack_effect.py deleted file mode 100644 index 1ddeff412d..0000000000 --- a/thunder/core/script/parse/stack_effect.py +++ /dev/null @@ -1,234 +0,0 @@ -import dis -import opcode -import sys -from typing import NewType, TypeAlias, TypeVar -from collections.abc import Callable -from collections.abc import Iterable - -from types import EllipsisType - -from thunder.core.utils import FrozenDict - -__all__ = ("stack_effect_detail", "fill_ellipses") - -T = TypeVar("T") -Pop = NewType("Pop", int) -Push = NewType("Push", int) -StackEffect: TypeAlias = tuple[Pop, Push] | tuple[Pop, tuple[Push, Push]] - -# Aliases for common cases -NoStackEffect = (Pop(0), Push(0)) -PushTOS = (Pop(0), Push(1)) -PopTOS = (Pop(1), Push(0)) -ReplaceTOS = (Pop(1), Push(1)) -BinaryOp = (Pop(2), Push(1)) - - -def make_function_detail(*args: int) -> Callable[[int], StackEffect]: - return lambda oparg: (Pop(2 + sum((oparg & flag) != 0 for flag in args)), Push(1)) - - -def fill_ellipses(**kwargs: T | EllipsisType) -> Iterable[tuple[str, T]]: - prior_effect: T | EllipsisType = Ellipsis - for opname, effect in kwargs.items(): - if effect is Ellipsis: - effect = prior_effect - assert effect is not Ellipsis - prior_effect = effect - yield opname, effect - - -__EFFECTS = dict[str, StackEffect | Callable[[int], StackEffect] | EllipsisType]( - NOP=NoStackEffect, # ∅ -> ∅ - EXTENDED_ARG=NoStackEffect, - # - # Stack manipulation - POP_TOP=PopTOS, # A -> ∅ - ROT_TWO=(Pop(2), Push(2)), # A,B -> B,A - ROT_THREE=(Pop(3), Push(3)), # A,B,C -> C,A,B - ROT_FOUR=(Pop(4), Push(4)), # A,B,C,D -> D,A,B,C - ROT_N=lambda oparg: (Pop(oparg), Push(oparg)), # A,B,...,Z -> Z,A,B,... - DUP_TOP=(Pop(1), Push(2)), # A -> A,A - DUP_TOP_TWO=(Pop(2), Push(4)), # A,B -> A,B,A,B - UNPACK_SEQUENCE=lambda oparg: (Pop(1), Push(oparg)), # A -> B,C,... - # - # Jumps & return - JUMP_FORWARD=NoStackEffect, # ∅ -> ∅ - JUMP_ABSOLUTE=..., - POP_JUMP_IF_FALSE=PopTOS, # A -> ∅ - POP_JUMP_IF_TRUE=..., - RETURN_VALUE=..., - JUMP_IF_NOT_EXC_MATCH=BinaryOp, # A,B -> ∅ - # - # Exceptions and context managers: - POP_BLOCK=NoStackEffect, # ∅ -> ∅ - POP_EXCEPT=(Pop(3), Push(0)), # A, B, C -> ∅ - RERAISE=..., - RAISE_VARARGS=lambda oparg: (Pop(oparg), Push(0)), # A,B,... -> ∅ - WITH_EXCEPT_START=(Pop(7), Push(8)), # ??!? - LOAD_ASSERTION_ERROR=PushTOS, # ∅ -> A - # - # Variable manipulation - LOAD_CONST=PushTOS, # ∅ -> A - LOAD_FAST=..., - LOAD_GLOBAL=..., - LOAD_NAME=..., - STORE_FAST=PopTOS, # A -> ∅ - STORE_GLOBAL=..., - STORE_NAME=..., - DELETE_FAST=NoStackEffect, # ∅ -> ∅ - DELETE_GLOBAL=..., - DELETE_NAME=..., - # - # Attributes - LOAD_METHOD=(Pop(1), Push(2)), # A -> B,A - LOAD_ATTR=ReplaceTOS, # A -> B - STORE_ATTR=(Pop(2), Push(0)), # A, B -> ∅ - DELETE_ATTR=PopTOS, # A -> ∅ - # - # Closures - LOAD_CLOSURE=PushTOS, # ∅ -> A - LOAD_DEREF=..., - LOAD_CLASSDEREF=..., - STORE_DEREF=PopTOS, # A -> ∅ - DELETE_DEREF=NoStackEffect, # ∅ -> ∅ - # - # Functions and calls A,B,... -> Z - CALL_FUNCTION=lambda x: (Pop(x + 1), Push(1)), - CALL_METHOD=lambda x: (Pop(x + 2), Push(1)), - CALL_FUNCTION_KW=..., - CALL_FUNCTION_EX=make_function_detail(0x01), - MAKE_FUNCTION=make_function_detail(0x01, 0x02, 0x04, 0x08), - # - # Build containers A,B,... -> Z - BUILD_TUPLE=lambda oparg: (Pop(oparg), Push(1)), - BUILD_LIST=..., - BUILD_SET=..., - BUILD_STRING=..., - BUILD_MAP=lambda oparg: (Pop(oparg * 2), Push(1)), - BUILD_CONST_KEY_MAP=lambda x: (Pop(x + 1), Push(1)), - LIST_TO_TUPLE=ReplaceTOS, # A -> B - # - # Insertion leaves container on the stack A,B -> A - SET_ADD=BinaryOp, - SET_UPDATE=..., - LIST_APPEND=..., - LIST_EXTEND=..., - DICT_MERGE=..., - DICT_UPDATE=..., - MAP_ADD=(Pop(3), Push(1)), # A,B,C -> A - COPY_DICT_WITHOUT_KEYS=(Pop(2), Push(2)), # A,B -> A,C (I am unsure...) - # - # Unary operators A -> B - UNARY_POSITIVE=ReplaceTOS, - UNARY_NEGATIVE=..., - UNARY_NOT=..., - UNARY_INVERT=..., - # - # Binary operators A,B -> C - BINARY_POWER=BinaryOp, - BINARY_MULTIPLY=..., - BINARY_MATRIX_MULTIPLY=..., - BINARY_MODULO=..., - BINARY_ADD=..., - BINARY_SUBTRACT=..., - BINARY_SUBSCR=..., - BINARY_FLOOR_DIVIDE=..., - BINARY_TRUE_DIVIDE=..., - INPLACE_FLOOR_DIVIDE=..., - INPLACE_TRUE_DIVIDE=..., - INPLACE_ADD=..., - INPLACE_SUBTRACT=..., - INPLACE_MULTIPLY=..., - INPLACE_MATRIX_MULTIPLY=..., - INPLACE_MODULO=..., - BINARY_LSHIFT=..., - BINARY_RSHIFT=..., - BINARY_AND=..., - BINARY_XOR=..., - BINARY_OR=..., - COMPARE_OP=..., - IS_OP=..., - CONTAINS_OP=..., - # - # Binary operators (inplace) - # https://docs.python.org/3/reference/datamodel.html?highlight=iadd#object.__iadd__ - # "... and return the result (which could be, but does not have to be, self)." - INPLACE_POWER=BinaryOp, - INPLACE_LSHIFT=..., - INPLACE_RSHIFT=..., - INPLACE_AND=..., - INPLACE_XOR=..., - INPLACE_OR=..., - # - # Indexing operators - STORE_SUBSCR=(Pop(3), Push(0)), # A,B,C -> ∅ - DELETE_SUBSCR=(Pop(2), Push(0)), # A,B -> ∅ - BUILD_SLICE=lambda x: (Pop(x), Push(1)), # A,B,... -> Z - UNPACK_EX=lambda x: (Pop(1), Push((x & 0xFF) + (x >> 8) + 1)), # A -> B,C,... - # - # Iterators - GET_ITER=ReplaceTOS, # A -> B - GET_YIELD_FROM_ITER=ReplaceTOS, - # - # Misc. - FORMAT_VALUE=lambda oparg: (Pop(1 + bool(oparg & 0x04)), Push(1)), # (A?),B -> C - PRINT_EXPR=PopTOS, # A -> ∅ - IMPORT_STAR=..., - LOAD_BUILD_CLASS=PushTOS, - SETUP_ANNOTATIONS=NoStackEffect, - GET_LEN=(Pop(1), Push(2)), - IMPORT_NAME=BinaryOp, - IMPORT_FROM=(Pop(1), Push(2)), - MATCH_CLASS=(Pop(3), Push(1)), - MATCH_MAPPING=(Pop(1), Push(2)), - MATCH_SEQUENCE=..., - MATCH_KEYS=(Pop(2), Push(3 + bool(sys.version_info < (3, 11)))), - # - # Jump dependent - FOR_ITER=(Pop(1), (Push(2), Push(0))), - SETUP_WITH=(Pop(1), (Push(2), Push(7))), - SETUP_FINALLY=(Pop(0), (Push(0), Push(6))), - SETUP_ASYNC_WITH=(Pop(0), (Push(0), Push(6))), - # - # NOTE: These instructions have been removed since they are extraneous special cases. - # https://github.com/faster-cpython/ideas/issues/567 - # https://github.com/python/cpython/issues/102859 - JUMP_IF_TRUE_OR_POP=(Pop(1), (Push(0), Push(1))), - JUMP_IF_FALSE_OR_POP=..., - # - # TODO(robieta, t-vi): Iterators and generators - # "GEN_START": PopTOS, # Where does TOS for this come from? - # "YIELD_VALUE": ReplaceTOS, # I think - # "YIELD_FROM": (2, PushNew), # I am very unsure - # "GET_AWAITABLE": (1, 1), - # "BEFORE_ASYNC_WITH": (1, 2), - # "GET_AITER": (1, 1), - # "GET_ANEXT": (1, 2), - # "END_ASYNC_FOR": (7, 0), -) - - -# Split so MyPy can type check `__EFFECTS` without having to go through `fill_ellipses`. -_RAW_STACK_EFFECTS = FrozenDict[str, StackEffect | Callable[[int], StackEffect]](fill_ellipses(**__EFFECTS)) -del __EFFECTS - - -def stack_effect_detail(instruction: dis.Instruction) -> tuple[Pop, tuple[Push, Push]]: - assert isinstance(instruction, dis.Instruction), instruction - if callable(effect := _RAW_STACK_EFFECTS[instruction.opname]): - assert instruction.arg is not None - effect = effect(instruction.arg) - - assert isinstance(effect, tuple) and len(effect) == 2 and isinstance(effect[0], int) - if isinstance(effect[1], int): - effect = (effect[0], (effect[1],) * 2) - - # Python exposes a method to compute stack effect, so while it's not part - # of the public API we may as well use it to check our bookkeeping. - pop, (push_nojump, push_jump) = effect - for jump, push in ((False, push_nojump), (True, push_jump)): - expected = opcode.stack_effect(instruction.opcode, instruction.arg, jump=jump) - assert expected == push - pop, (expected, push, pop, jump) - - return Pop(pop), (Push(push_nojump), Push(push_jump)) diff --git a/thunder/core/script/passes.py b/thunder/core/script/passes.py deleted file mode 100644 index a53581ec01..0000000000 --- a/thunder/core/script/passes.py +++ /dev/null @@ -1,932 +0,0 @@ -import dis -import copy -import inspect -import opcode -import sys -import types -from typing import Any, Dict, List, Tuple, Union -from collections.abc import Callable -from collections.abc import Hashable -from contextvars import ContextVar - -import networkx as nx -import torch # # aehem. - -import thunder -from thunder.core.script.frontend import acquire_method, remove_unused_values -from thunder.core.script.graph import ( - assert_block, - assert_node, - assert_value, - Graph, - Block, - clone_blocks, - _generate_raises, - GraphObject, - Node, - PhiValue, - replace_values, - SourceInformation, - _Undefined, - Value, - repr_source_location, -) -from thunder.core.script.instrumentation import verbose_error, record -from thunder.core.script.parse import ThunderInstruction, JUMP_ABSOLUTE -from thunder.core.script.python_ir_data import get_instruction, X_THUNDER_STORE_ATTR -from thunder.torch import _torch_to_thunder_complete_map -from thunder.core.script.noinline import NOINLINE_METHODS -from thunder.core.utils import debug_asserts_enabled, debug_asserts_level, OrderedSet - -MAX_INLINE_ITERS = 50 - - -def split_block(gr: "Graph", bl: "Block", n: "Node") -> Block: - # The admin involved: - # - create a new "bottom block", the input block is the "top block" - # - split the .nodes - # - block_inputs of the top block and block_outputs of the bottom are the original - # block_inputs and block_outputs - # - scan all the node inputs and block_outputs of the lower part to see - # which need to be block_inputs of the lower block and thus outputs of the top one - # - define outputs of the "top block" to be the required inputs - # - add the input PhiValues and replace the outputs of the top block with them in the - # uses in the bottom block - # - add unconditional jump from top to bottom part - - if debug_asserts_level() > 1: - thunder.core.script.graph.check_graph(gr) - i = 0 - while i < len(gr.blocks) and gr.blocks[i] is not bl: - i += 1 - assert i < len(gr.blocks), "block not found" - j = 0 - while j < len(bl.nodes) and bl.nodes[j] is not n: - j += 1 - assert j < len(bl.nodes), "node not found" - nbl = Block() - nbl.nodes = bl.nodes[j:] - del bl.nodes[j:] - old_block_outputs = bl.block_outputs - nbl.block_outputs = OrderedSet() - bl.block_outputs = OrderedSet() - nbl.block_inputs = [] - - bl_jump_node = Node(i=ThunderInstruction.make_jump_absolute(arg=None), inputs=[], outputs=[]) - bl_jump_node.jump_targets = [nbl] - if bl.nodes: - bl_jump_node.source_infos = copy.deepcopy(bl.nodes[-1].source_infos) - else: - bl_jump_node.source_infos = copy.deepcopy(nbl.nodes[0].source_infos) - bl.nodes.append(bl_jump_node) - nbl.jump_sources.append(bl_jump_node) - nbl.graph = gr - gr.blocks.insert(i + 1, nbl) - - potential_bl_outputs = {i for i in bl.block_inputs} - for n in bl.nodes: - for o in n.outputs: - potential_bl_outputs.add(o) - for i in bl.block_inputs: - potential_bl_outputs.add(i) - value_map: dict[GraphObject, GraphObject] = {} - - def get_or_create_phi(v: Value) -> Value: - if v in value_map: - return assert_value(value_map[v]) - if v.is_const or v.is_global: - return v - if v in potential_bl_outputs: # priority follow parent vs. phi_value? - phi_value = PhiValue([v], [bl_jump_node], nbl) - nbl.block_inputs.append(phi_value) - bl.block_outputs.add(v) - value_map[v] = phi_value - return phi_value - if v.parent is not None: - # this adds v.parent to the value_map, so that is used - # for the clone's parent - get_or_create_phi(v.parent) - v_new = v.clone(translation_dict=value_map) - v_new.block = nbl - return v_new - raise ValueError(f"unknwn value {v}") - - for n in nbl.nodes: - n.inputs = [get_or_create_phi(i) for i in n.inputs] - for o in n.outputs: - o.block = nbl - value_map[o] = o - - for o in old_block_outputs: - if o not in value_map: - bl.block_outputs.add(o) - else: - assert value_map[o].block is nbl or ( - value_map[o].is_function_arg or value_map[o].is_global - ), f"value {repr(o)} mapped to {repr(value_map[o])} has block {gr.blocks.index(value_map[o].block)} instead of {gr.blocks.index(nbl)}" - nbl.block_outputs.add(value_map[o]) - if o is not value_map[o]: - for pv in o.phi_values[:]: - if pv.block is not nbl: - pv.replace_value(o, value_map[o]) - - if debug_asserts_level() > 1: - thunder.core.script.graph.check_graph(gr) - - return nbl - - -@verbose_error -def find_method_through_phi_parent(fn_value: Value) -> tuple[Value, list[str]]: - Point = tuple[Value, tuple[str, ...]] - to_process: list[Point] = [(v, ()) for v in fn_value.resolve()] - edges: OrderedSet[tuple[Point, Point]] = OrderedSet(((fn_value, ()), i) for i in to_process) - while to_process: - v, attr = to_process.pop() - destination = (v, attr) - if (parent := v.parent) is not None and (name := v.name) is not None: - destination = (parent, (name, *attr)) - - elif (node := v.node) is not None and node.i.opname == "BINARY_SUBSCR" and node.inputs[1].is_const: - destination = (node.inputs[0], (repr(node.inputs[1].value), *attr)) - - for vi in destination[0].resolve(): - edge = ((v, attr), (vi, destination[1])) - if edge not in edges: - edges.add(edge) - to_process.append(edge[1]) - - G = nx.from_edgelist(edges, nx.DiGraph) - G.remove_edges_from(nx.selfloop_edges(G)) - assert nx.is_connected(G.to_undirected()) - assert nx.is_directed_acyclic_graph(G) - - # A size one topological generation means all flow must pass through that node. Thus, the latest - # generation with that property is the farthest we can resolve attributes. - *_, (fn_value, attr_lookups) = (i for i, *other in nx.topological_generations(G) if not other) - return fn_value, list(attr_lookups) - - -def find_and_evaluate_method_through_phi_parent(v: Value) -> object | Callable: - fn_parent_value, attr_lookups = find_method_through_phi_parent(v) - if fn_parent_value.value is None: - return None - fn_value = fn_parent_value.value - for al in attr_lookups: - value = getattr(fn_value, al, _Undefined) - if value is _Undefined: - return _Undefined(fn_value, al) - fn_value = value - return fn_value - - -class SkipInlineError(NotImplementedError): - pass - - -@record(delegate_to="n") -def inline_method_call(gr: "Graph", n: "Node") -> None: - gr.ensure_links() - if debug_asserts_level() > 1: - thunder.core.script.graph.check_graph(gr) - found_block = False - for i_bl, bl in enumerate(gr.blocks): - for i_n, n1 in enumerate(bl.nodes): - if n1 is n: # is? - found_block = True - break - if found_block: - break - assert found_block - if n.i.opname == "CALL_METHOD": - fn_value: Callable = find_and_evaluate_method_through_phi_parent(n.inputs[0]) # type: ignore - assert not isinstance(fn_value, _Undefined) - if fn_value is None: - raise NotImplementedError("cannot inline non-explicit function") - - ## TODO: value for self arg in Method calls? - ### in general: What is with callables here? - if isinstance(fn_value, torch.nn.Module): - mod1: object = fn_value - value_for_self1 = n.inputs[0] - fn_value = fn_value.forward - elif isinstance(fn_value, types.MethodType): - mod1 = fn_value.__self__ - value_for_self1 = n.inputs[1] - else: - mod1 = None - value_for_self1 = None - - if inspect.isbuiltin(fn_value): - raise NotImplementedError("cannot inline built-in (C-implemented) function") - elif n.i.opname in {"CALL_FUNCTION", "CALL_FUNCTION_KW"}: - fn_value = find_and_evaluate_method_through_phi_parent(n.inputs[0]) # type: ignore - assert not isinstance(fn_value, _Undefined) - if fn_value is None: - raise NotImplementedError("cannot inline non-explicit function") - - if isinstance(fn_value, torch.nn.Module): - mod1 = fn_value - value_for_self1 = n.inputs[0] - fn_value = fn_value.forward - else: - if isinstance(fn_value, types.FunctionType): - mod1 = None - value_for_self1 = None - elif isinstance(fn_value, types.MethodType): - mod1 = fn_value.__self__ - value_for_self1 = n.inputs[0].parent - assert value_for_self1 is not None - else: - source_str = repr_source_location(gr, n.source_infos) - raise NotImplementedError(f"inlining {fn_value} in instruction {n} at\n{source_str}") - else: - raise NotImplementedError(f"inlining {n}") - - # splitting must be done before replacing values, but this is changed even if we don't inline... - nbl = split_block(gr, bl, bl.nodes[i_n + 1]) - - gr1 = acquire_method(fn_value, module=mod1, mro_klass=gr.mro_klass if mod1 == gr.module else None) - for gr1_n in gr1.nodes(): - assert gr1_n.source_infos - have_generated = False - for si in gr1_n.source_infos: - si.gen_line_no = si.gen_line_no + len(gr.source_lines) + 1 - si.gen_end_line_no = si.gen_end_line_no + len(gr.source_lines) + 1 - # prepend - gr1_n.source_infos[:0] = copy.deepcopy(n.source_infos) - gr.source_lines.append("\n") - gr.source_lines += gr1.source_lines - - if gr1.ismethod: - sig1 = inspect.signature(gr1.method.__func__) - else: - sig1 = inspect.signature(gr1.method) - # transform defaults - sig1 = sig1.replace( - parameters=[ - p - if p.default is inspect._empty - else p.replace(default=Value(name=p.name, typ=type(p.default), value=p.default, is_const=True)) - for p in sig1.parameters.values() - ] - ) - - if gr1.ismethod: - call_args = [value_for_self1] - else: - call_args = [] - - if n.i.opname == "CALL_METHOD": - call_args += n.inputs[2:] - call_kwargs: dict[str, Any] = {} - elif n.i.opname == "CALL_FUNCTION": - call_args += n.inputs[1:] - call_kwargs = {} - elif n.i.opname == "CALL_FUNCTION_KW": - assert n.inputs[-1].is_const - num_kwargs = len(n.inputs[-1].value) - call_kwargs = {k: v for k, v in zip(n.inputs[-1].value, n.inputs[-1 - num_kwargs : -1])} - call_args += n.inputs[1 : -1 - num_kwargs] - else: - raise NotImplementedError() - - # TODO: catch and translate error messages, check types(?) - bound_args = sig1.bind(*call_args, **call_kwargs) - bound_args.apply_defaults() - - gr1_varargs = [n for n, p in sig1.parameters.items() if p.kind == p.kind.VAR_POSITIONAL] - gr1_varkwargs = [n for n, p in sig1.parameters.items() if p.kind == p.kind.VAR_KEYWORD] - ## TODO: TRANSLATE args (=tuple of Values) and kwargs (=dict str->Value) to a Value to something Value of ... (probably needs at least BUILD_TUPLE etc) - if gr1_varargs or gr1_varkwargs: - raise SkipInlineError("varargs and kwargs are currently not implemented") - - n1 = bl.nodes.pop(i_n) - assert n1 is n - - # there should be exactly one - (ret_bl,) = (bl for bl in gr1.blocks if len(bl.nodes) > 0 and bl.nodes[-1].i.opname == "RETURN_VALUE") - - ret_node = ret_bl.nodes[-1] - ret_node.i = ThunderInstruction.make( - JUMP_ABSOLUTE, - arg=-1, - argrepr="None", - offset=ret_node.i.offset, - starts_line=ret_node.i.starts_line, - is_jump_target=ret_node.i.is_jump_target, - line_no=ret_node.i.line_no, - ) - bl.nodes[-1].jump_targets = [gr1.blocks[0]] - assert len(gr1.blocks[0].jump_sources) == 1 - gr1.blocks[0].jump_sources = [bl.nodes[-1]] - for pv in gr1.blocks[0].block_inputs: - assert pv.jump_sources == [None] - pv.jump_sources = [bl.nodes[-1]] - ret_node.jump_targets = [nbl] - nbl.jump_sources = [ret_node if js == bl.nodes[-1] else js for js in nbl.jump_sources] - for pv in nbl.block_inputs: - pv.jump_sources = [ret_node if js == bl.nodes[-1] else js for js in pv.jump_sources] - - for bl1 in gr1.blocks: - bl1.graph = gr - gr.blocks[i_bl + 1 : i_bl + 1] = gr1.blocks - - assert len(n.outputs) == 1 - inp_map = {p: bound_args.arguments[p.name] for p in gr1.local_variables_at_start if p.name in bound_args.arguments} - if n.outputs[0] in bl.block_outputs: # it may legitimately happen that we don't use the output - bl.block_outputs.remove(n.outputs[0]) # TODO: what with inplace!! - bl.block_outputs.update(inp_map.values()) # Note: This includes default args - gr.ensure_links() - replace_values(gr1, inp_map) - - # output value - rv = ret_node.inputs.pop() - assert not ret_node.inputs - (orv,) = n.outputs - replace_values(gr, {orv: rv}) - ret_bl.block_outputs.add(rv) - if debug_asserts_level() > 1: - thunder.core.script.graph.check_graph(gr) - - -def inline_submodule_calls(gr: "Graph") -> bool: - # inlines submodule calls - # returns whether something has changed - # TODO: recursively and not from nested structures (ModuleList etc.) - changed = False - gr.ensure_links() - for bl in gr.blocks[:]: - for n in bl.nodes[:]: - if n.i.opname in {"CALL_METHOD", "CALL_FUNCTION", "CALL_FUNCTION_KW"}: - fn_value = find_and_evaluate_method_through_phi_parent(n.inputs[0]) - if isinstance(fn_value, _Undefined): - # TODO: We could insert a RAISE here if we then delete the return - # value and all (direct or indirect) uses. - methval = Value( - value=_generate_raises( - f"attribute error '{type(fn_value.value)}' object has no attribute '{fn_value.attr}'" - ), - is_const=True, - ) - n.i = n.i.modify_copy(opname="CALL_FUNCTION", arg=0, opcode=None) - n.inputs = [methval] - if isinstance(fn_value, torch.nn.Module) or ( - inspect.ismethod(fn_value) - and isinstance(fn_value.__self__, torch.nn.Module) - and (fn_value not in NOINLINE_METHODS.get()) - ): - inline_method_call(gr, n) - changed = True - - return changed - - -def strongly_inline_functions(gr: "Graph") -> None: - for _ in range(MAX_INLINE_ITERS): - loop = False - gr.ensure_links() - for bl in gr.blocks[:]: - for n in bl.nodes[:]: - if n.i.opname in {"CALL_METHOD", "CALL_FUNCTION", "CALL_FUNCTION_KW"}: - fn_value = find_and_evaluate_method_through_phi_parent(n.inputs[0]) - if ( - fn_value is not None - and not inspect.isbuiltin(fn_value) - and isinstance(fn_value, types.FunctionType) - and fn_value not in _torch_to_thunder_complete_map - and fn_value not in NOINLINE_METHODS.get() - ): - ## handle methods or nn.Modules / other classes? - try: - inline_method_call(gr, n) - loop = True - except SkipInlineError: - pass - except RuntimeError as e: - (msg,) = e.args - source_str = repr_source_location(gr, n.source_infos) - msg = f"{msg}\nwhile inlining:\n{source_str}" - e.args = (msg,) - raise e - if not loop: - return - - raise AssertionError(f"Inlining did not complete after {MAX_INLINE_ITERS} passes.") - - -def torch_to_thunder(gr: "Graph", fallback: bool = False) -> None: - """replaces calls to torch.foo functions with calls into thunder's torch language.""" - - def fill_in_value(v: Value, seen: OrderedSet[Value]) -> None: - if v in seen: - return - seen.add(v) - parent = v.parent - if parent is None and isinstance(v, PhiValue): - for vv in v.values: - fill_in_value(vv, seen) - for vv in v.values[1:]: - if vv.value is not v.values[0].value: - return - v.value = v.values[0].value - if v.value is None and parent is not None: - fill_in_value(parent, seen) - if v.name is None and isinstance(v, PhiValue) and parent is not None and parent.name is not None: - v.name = parent.name - if v.value is None and parent is not None and parent.value is not None and v.name is not None: - v.value = getattr(parent.value, v.name, None) - - for bl in gr.blocks: - for n in bl.nodes: - for idx, i in enumerate(n.inputs): - done = False - fill_in_value(i, OrderedSet()) - i_or_parent = i - while ( - not isinstance(i_or_parent.value, Hashable) - or i_or_parent.value not in _torch_to_thunder_complete_map - ) and i_or_parent.parent is not None: - i_or_parent = i_or_parent.parent - - if isinstance(i_or_parent.value, Hashable) and i_or_parent.value in _torch_to_thunder_complete_map: - i_or_parent.value = _torch_to_thunder_complete_map[i.value] - # we reinstantiate because we don't want a PhiValue here - i_new = Value( - value=i_or_parent.value, - typ=type(i_or_parent.value), - parent=None, - is_const=True, - is_global=False, - name=i_or_parent.name, - ) - n.inputs[idx] = i_new - if n.i.opname == "CALL_METHOD" and idx == 0: - # todo get others, too - n.i = get_instruction(opname="CALL_FUNCTION", arg=n.i.arg) - del n.inputs[1] - done = True - - if (not done) and fallback: # fallback - # todo: change name?, deeper nesting? - if i.value == torch: - i.value = thunder.langs.torch - if i.parent is not None and i.parent.value == torch: - i.parent.value = thunder.langs.torch - assert i.name is not None - i.value = getattr(thunder.langs.torch, i.name) - - # replace other things by checking against torch module (make dict at startup?) - name = getattr(i.value, "__name__", None) - tf = None - if name is not None: - tf = getattr(torch, name, None) - if tf is not None and i.value == tf: - i.value = getattr(thunder.langs.torch, name) - i.is_global = False - i.is_const = True - - -def merge_two_blocks(gr: "Graph", bl1: "Block") -> None: - if debug_asserts_level() > 1: - thunder.core.script.graph.check_graph(gr) - jt = bl1.nodes[-1].jump_targets - if len(jt) != 1: - raise RuntimeError("can only fuse blocks with deterministic connection") - bl2 = jt[0] - if len(bl2.jump_sources) != 1 or bl2.jump_sources[0] != bl1.nodes[-1]: - raise RuntimeError("second block to be fused must only have first block as jump source") - - replacements: dict[Value, Value] = {} - for i in bl2.block_inputs: - assert isinstance(i, PhiValue) and len(i.values) == 1, (i, getattr(i, "values", None)) - (iv,) = i.values - if iv in bl1.block_outputs: - replacements[i] = iv - else: - if i.jump_sources == [bl1.nodes[-1]]: - i.jump_sources = [iv.block.nodes[-1]] - bl1.block_inputs.append(i) - i.block = bl1 - - replace_values(bl2, replacements, follow_phi_values=True) - # TODO: Should this happen automatically in replace_values? - # Should we also replace values in bl1? - for o in bl1.block_outputs: - for pv in o.phi_values[:]: - if pv in replacements: - pv.remove_value(o) - else: - pv.jump_sources = [js if js != bl1.nodes[-1] else bl2.nodes[-1] for js in pv.jump_sources] - - bl1_jump = bl1.nodes[-1] - bl2_jump = bl2.nodes[-1] - - bl1.block_outputs = OrderedSet(o for o in bl1.block_outputs if o.phi_values) - bl1.block_outputs.update(bl2.block_outputs) - - bl1.nodes[-1:] = bl2.nodes - gr.blocks.remove(bl2) - - gr.ensure_links() - - # fix jump sources in other blocks - for bl in gr.blocks: - for i in bl.block_inputs: - i.jump_sources = [(bl2_jump if js is bl1_jump else js) for js in i.jump_sources] - - if debug_asserts_level() > 1: - thunder.core.script.graph.check_graph(gr) - - -def merge_blocks_where_possible(gr: "Graph") -> None: - i_bl = 0 - while i_bl < len(gr.blocks): - bl1 = gr.blocks[i_bl] - jt = bl1.nodes[-1].jump_targets - if len(jt) == 1: - bl2 = jt[0] - else: - bl2 = None - if bl2 is not None and len(bl2.jump_sources) == 1 and bl2.jump_sources[0] == bl1.nodes[-1]: - merge_two_blocks(gr, bl1) - else: - i_bl += 1 - - -def find_blocks_of_for(gr: "Graph", for_block: "Block") -> list[Block]: - assert for_block.nodes[-1].i.opname == "FOR_ITER" - - blocks_of_for_loop = OrderedSet({for_block}) - currently_looking_at = set() - - def find_blocks_of_for_rec(for_block: "Block", start_block: "Block") -> bool: - if for_block == start_block: - return True - if start_block in currently_looking_at: - return False - currently_looking_at.add(start_block) - found = False - for jt in start_block.nodes[-1].jump_targets: - found |= find_blocks_of_for_rec(for_block, jt) - currently_looking_at.remove(start_block) - if found: - blocks_of_for_loop.add(start_block) - return found - - find_blocks_of_for_rec(for_block, for_block.nodes[-1].jump_targets[0]) - return list(blocks_of_for_loop) - - -def unroll_for_over_modules(gr: "Graph", for_iter_node: "Node") -> None: - gr.ensure_links() - if debug_asserts_level() > 1: - thunder.core.script.graph.check_graph(gr) - get_iter_node = for_iter_node.inputs[0].values[0].node - assert get_iter_node.i.opname == "GET_ITER" - - iterated_module_list_parent, attr_lookups = find_method_through_phi_parent(get_iter_node.inputs[0]) - assert iterated_module_list_parent.value is not None - iterated_module_list = iterated_module_list_parent.value - for al in attr_lookups: - iterated_module_list = getattr(iterated_module_list, al) - - # what about more complex things? - assert isinstance(iterated_module_list, (torch.nn.Sequential, torch.nn.ModuleList)) - - for_loop_len = len(iterated_module_list) - for_iter_block = for_iter_node.block - assert for_iter_block is not None - get_iter_block = get_iter_node.block - - (iter_v,) = get_iter_node.outputs - (iter_phi,) = for_iter_node.inputs - - assert isinstance(iter_phi, PhiValue) - assert iter_v in iter_phi.values - - ### first we find the blocks of the for loop - bls = find_blocks_of_for(gr, for_iter_block) - - jmp_nodes = {bl.nodes[-1] for bl in bls} - assert all((v is iter_v or js in jmp_nodes) for v, js in zip(iter_phi.values, iter_phi.jump_sources)) - - for_iter_node.i = get_instruction(opname="BINARY_SUBSCR", arg=None) - iter_phi.remove_value(iter_v) - assert len(iter_v.phi_values) == 0 - get_iter_block.block_outputs.remove(iter_v) - - get_iter_block.block_outputs.add(get_iter_node.inputs[0]) - - seen = set() - - def delete_value_and_sources(v: Value) -> None: - # check that it is possible? - if v in seen: - return - seen.add(v) - if isinstance(v, PhiValue): - for vv, js in zip(v.values, v.jump_sources): - delete_value_and_sources(vv) - assert js is not None and js.block is not None - js.block.block_outputs.remove(vv) - v.block.block_inputs.remove(v) - - delete_value_and_sources(iter_phi) - seq_phi = PhiValue(values=[get_iter_node.inputs[0]], jump_sources=[get_iter_block.nodes[-1]], block=for_iter_block) - get_iter_block.nodes.remove(get_iter_node) - for_iter_block.block_inputs.append(seq_phi) - - idx = Value(value=0, is_const=True) - for_iter_node.inputs = [seq_phi, idx] - for_iter_node.outputs = [for_iter_node.outputs[1]] - - for_iter_block_jmp = Node(i=get_instruction(opname="JUMP_ABSOLUTE", arg=None)) - for_iter_block_jmp.source_infos = copy.deepcopy(for_iter_node.source_infos) - for_iter_block.nodes.append(for_iter_block_jmp) - for_iter_block_jmp.jump_targets = [for_iter_node.jump_targets[0]] - for_iter_node_exit_jump_target = for_iter_node.jump_targets[1] - for_iter_node.jump_targets = [] - for_iter_block_jmp.jump_targets[0].jump_sources = [ - (js if js is not for_iter_node else for_iter_block_jmp) - for js in for_iter_block_jmp.jump_targets[0].jump_sources - ] - - exit_block = Block() - gr.blocks.append(exit_block) - exit_node = Node(i=get_instruction(opname="JUMP_ABSOLUTE", arg=None)) - exit_node.source_infos = copy.deepcopy(for_iter_node.source_infos) - exit_node.jump_targets = [for_iter_node_exit_jump_target] - target_after_iter = exit_node.jump_targets[0] - exit_node.jump_targets[0].jump_sources = [ - (js if js is not for_iter_node else exit_node) for js in exit_node.jump_targets[0].jump_sources - ] - exit_block.nodes.append(exit_node) - for i in for_iter_block.block_inputs: - exit_block.block_inputs.append(PhiValue([], [], exit_block)) - - unroll_blocks: list[tuple[list[Block], dict[GraphObject, GraphObject]]] = [(list(bls), {})] - unroll_blocks += [clone_blocks(bls) for _ in range(1, for_loop_len)] - for idx, (nbls, td) in enumerate(unroll_blocks): - if idx > 0: - gr.blocks += nbls - v_idx = Value(value=idx, is_const=True) - assert_node(td[for_iter_node]).inputs[1] = v_idx - fin_o = assert_node(td[for_iter_node]).outputs[0] - assert fin_o.name is not None - fin_o.name += f"_{idx}" - else: - assert for_iter_node.outputs[0].name is not None - for_iter_node.outputs[0].name += "_0" - - gr.ensure_links() - - fixup_data = [] - for idx, (nbls, td) in enumerate(unroll_blocks): - if idx == 0: - fib_i = for_iter_block - jump_sources_to_fix = [js for js in for_iter_block.jump_sources if js is not get_iter_block.nodes[-1]] - else: - fib_i = assert_block(td[for_iter_block]) - jump_sources_to_fix = fib_i.jump_sources[:] - if idx + 1 < len(unroll_blocks): - _, td_next = unroll_blocks[idx + 1] - fib_next = assert_block(td_next[for_iter_block]) - else: - fib_next = exit_block - - fixup_data.append((fib_i, jump_sources_to_fix, fib_next, nbls)) - - for idx_it, (fib_i, jump_sources_to_fix, fib_next, nbls) in enumerate(fixup_data): - for js in jump_sources_to_fix: - assert js is not None - for idx, jt in enumerate(js.jump_targets): - if jt == fib_i: - js.set_jump_target(fib_next, idx=idx) - - for idx_i, i in enumerate(fib_i.block_inputs): - if any((js.block in nbls) for js in i.jump_sources): - ## if this is a variable updated in the loop: - ## - instead of looping back, point the update to the phi value of the next block (or the exit block) - ## - if idx > 0: remove external (before the loop) value - for v, js in zip(i.values[:], i.jump_sources[:]): - if js is not None and js.block not in nbls and idx_it > 0: - i.remove_value(v) - - for idx_it, (fib_i, jump_sources_to_fix, fib_next, nbls) in enumerate(fixup_data): - for idx_i, i in enumerate(fib_i.block_inputs): - if any((js is not None and js.block in nbls) for js in i.jump_sources): - for v, js in zip(i.values[:], i.jump_sources[:]): - if js is not None and assert_block(assert_node(js).block) in nbls: - i.remove_value(v) - assert_block(fib_next).block_inputs[idx_i].add_missing_value(v, jump_source=js) - if idx_it == 0: - for pv in i.phi_values[:]: - if pv.block is target_after_iter: - pv.remove_value(i) - pv.add_missing_value(exit_block.block_inputs[idx_i], jump_source=exit_node) - - for i in exit_block.block_inputs[:]: - if i.phi_values: - exit_block.block_outputs.add(i) - else: - assert isinstance(i, PhiValue) - for v in i.values[:]: - i.remove_value(v) - exit_block.block_inputs.remove(i) - if debug_asserts_enabled(): - thunder.core.script.graph.check_graph(gr) - - -def find_and_unroll_for_loop(gr: "Graph") -> bool: - if debug_asserts_level() > 1: - thunder.core.script.graph.check_graph(gr) - gr.ensure_links() - - for bl in gr.blocks[:]: - for n in bl.nodes[:]: - if n.i.opname == "FOR_ITER": - for_iter_node = n - get_iter_node = for_iter_node.inputs[0].values[0].node - if get_iter_node.i.opname == "GET_ITER": - ( - iterated_module_list_parent, - attr_lookups, - ) = find_method_through_phi_parent(get_iter_node.inputs[0]) - if iterated_module_list_parent.value is None: - continue - iterated_module_list = iterated_module_list_parent.value - for al in attr_lookups: - iterated_module_list = getattr(iterated_module_list, al) - # what about more complex things? in particular enumerate, but zip, ... - if isinstance(iterated_module_list, (torch.nn.Sequential, torch.nn.ModuleList)): - thunder.core.script.passes.unroll_for_over_modules(gr, for_iter_node) - if debug_asserts_level() > 1: - thunder.core.script.graph.check_graph(gr) - thunder.core.script.passes.merge_blocks_where_possible(gr) - if debug_asserts_level() > 1: - thunder.core.script.graph.check_graph(gr) - return True - if debug_asserts_enabled(): - thunder.core.script.graph.check_graph(gr) - return False - - -def unroll_for_loops_and_inline_modules(gr: "Graph") -> None: - if debug_asserts_level() > 1: - thunder.core.script.graph.check_graph(gr) - iterate = True - while iterate: - iterate = find_and_unroll_for_loop(gr) - if not iterate: - iterate = inline_submodule_calls(gr) - if iterate: - thunder.core.script.passes.merge_blocks_where_possible(gr) - - -def module_to_function(gr: "Graph") -> tuple[list[str], list[torch.Tensor]]: - attr_dict: dict[str, int] = {} - attr_list: list[str] = [] - attr_values = [] - return_values: dict[str, Value] = {} # PhiValues in the return block - - if debug_asserts_enabled(): - thunder.core.script.graph.check_graph(gr) - - def functionalize_value_if_possible(i): - # TODO: inefficient because it looks twice - v = find_and_evaluate_method_through_phi_parent(i) - # assert not isinstance(v, _Undefined), f"undefined: {v.value} {v.attr}" - if isinstance(v, _Undefined): - return Value(value=v, is_const=True) - maybe_self, attrs = find_method_through_phi_parent(i) - - attr_string = ".".join(attrs) - if maybe_self.value is gr.module and (isinstance(v, torch.Tensor) or (attr_string in return_values)): - # the new attributes come directly after the self argument - idx = attr_dict.setdefault(attr_string, len(attr_list) + 1) - if idx == len(attr_list) + 1: - func_arg = Value(name=attr_string, is_function_arg=True) - gr.local_variables_at_start.insert(idx, func_arg) - attr_list.append(attr_string) - attr_values.append(v) - gr.co_argcount += 1 - # we need a default argument to be able to put the things at the end (but this will have to change for *args, **kwargs anyway... - # gr.func_defaults.append(None) - if attr_string in return_values: - return_values[attr_string].add_missing_value(func_arg) - else: - func_arg = gr.local_variables_at_start[idx] - - pvs = [pv for pv in func_arg.phi_values if pv.block is bl] - if not pvs: - pv = PhiValue([func_arg], [None], bl) - bl.block_inputs.append(pv) - else: - (pv,) = pvs - ## remove old input from phi_values etc? - return pv - if maybe_self.value is gr.module and ( - n.i.opname not in {"BINARY_SUBSCR"} and not isinstance(v, torch.nn.Module) - ): - ## inline to const... - i.value = v - i.typ = type(i.value) - i.parent = None - i.is_const = True - i.is_global = False - return None - return None - - return_block = None - for bl in gr.blocks: - if bl.nodes[-1].i.opname == "RETURN_VALUE": - assert return_block is None, "multiple return statements should not happen here" - return_block = bl - assert return_block is not None, "could not find return block" - - for bl in gr.blocks: - for n in bl.nodes: - if n.i.opname == "STORE_ATTR": - v = find_and_evaluate_method_through_phi_parent(n.inputs[1]) - if isinstance(v, _Undefined): - n.inputs[1] = Value(value=v, is_const=True) - continue - # assert not isinstance(v, _Undefined), f"undefined: {v.value} {v.attr}" - maybe_self, attrs = find_method_through_phi_parent(n.inputs[1]) - attrs.append(n.i.argval) - if maybe_self.value is gr.module: - attr_string = ".".join(attrs) - n.i = n.i.modify_copy(opname=X_THUNDER_STORE_ATTR, opcode=None, argval=attr_string) - pv = return_values.get(attr_string) - if pv is None: - pv = PhiValue([], [], return_block) - pv.name = attr_string - return_values[attr_string] = pv - return_block.block_inputs.append(pv) - v = Value(node=n, name=attr_string, block=bl) # disambiguate? - pv.add_missing_value(v, jump_source=bl.nodes[-1]) - n.outputs = [v] - bl.block_outputs.add(v) - del n.inputs[1] - - for bl in gr.blocks: - for n in bl.nodes: - if n.i.opname == "CALL_METHOD": - if n.inputs[0].parent == n.inputs[1]: - v = find_and_evaluate_method_through_phi_parent(n.inputs[0]) - if not isinstance(v, types.MethodType) or v.__self__ != find_and_evaluate_method_through_phi_parent( - n.inputs[1] - ): - # this case (not a proper method call is usually handled in executing the LOAD_METHOD opcode) - n.i = n.i.modify_copy(opname="CALL_FUNCTION", opcode=None) - del n.inputs[1] - - for idx_i, i in enumerate(n.inputs): - v = functionalize_value_if_possible(i) - if v is not None: - n.inputs[idx_i] = v - - bl.block_outputs = OrderedSet( - [v if (v := functionalize_value_if_possible(o)) is not None else o for o in bl.block_outputs] - ) - - if return_values: - bt_extra = Node( - i=get_instruction(opname="BUILD_TUPLE", arg=1 + len(return_values)), - source_infos=copy.deepcopy(return_block.nodes[-1].source_infos), - ) - bt_extra.inputs = return_block.nodes[-1].inputs + list(return_values.values()) - v_tuple_extra = Value(node=bt_extra, block=return_block) - bt_extra.outputs = [v_tuple_extra] - return_block.nodes.insert(-1, bt_extra) - return_block.nodes[-1].inputs = [v_tuple_extra] - - remove_unused_values(gr) - if gr.local_variables_at_start[0].phi_values: - gr.summary(print_lines=True) - raise RuntimeError( - """could not eliminate self argument - this most likely means that you are setting attributes in forward or using them - in an unexpected way that thunder does not yet support. - The problem lies in (indirect) uses of V_0 in the graph above.""" - ) - - # check to avoid assignments for both a.b and a.b.c - sorted_keys = sorted(return_values.keys()) # this uses that '.' sorts before other things - for i in range(len(sorted_keys) - 1): - kbase = sorted_keys[i] - knext = sorted_keys[i + 1] - if knext.startswith(kbase) and knext[len(kbase)] == ".": - # N.B. we know that knext is longer if kbase is a prefix so the knext[len(kbase)] above will not be out of bounds. - raise RuntimeError(f"Assigning to members of assigned members ('{kbase}' and '{knext}') is not supported.") - - del gr.local_variables_at_start[0] - gr.co_argcount -= 1 - if gr.co_posonlyargcount > 0: - gr.co_posonlyargcount -= 1 - - # thunder.core.script.graph.check_graph(gr) - # gr.summary(print_lines=True) - - return attr_list, attr_values, list(return_values.keys()) diff --git a/thunder/core/script/protograph.py b/thunder/core/script/protograph.py deleted file mode 100644 index 169e4bc60c..0000000000 --- a/thunder/core/script/protograph.py +++ /dev/null @@ -1,518 +0,0 @@ -from __future__ import annotations - - -import abc -import collections -import dataclasses -import functools -import inspect -import itertools -from types import CodeType -from typing import cast, overload, Any, Literal, NewType -from collections.abc import Iterable, Iterator, Mapping - -from thunder.core.script import algorithms, instrumentation, parse, values -from thunder.core.utils import debug_asserts_enabled, FrozenDict, OrderedSet - -__all__ = ("ProtoBlock", "ProtoGraph") - -# ============================================================================= -# == Inter-ProtoBlock abstract value flow ===================================== -# ============================================================================= -# -# ProtoBlocks are weakly coupled by design. The `VariableKey` slots allow edges -# to be deduced (e.g. `x` at the start of one block must be the same as `x` at -# the end of the prior block), but there's no strong requirement. (And indeed, -# the ProtoGraph immediately after parsing has all unconnected `AbstractRef`s -# for input values.) Similarly, ProtoGraph serves only to record organize the -# block topology, check invariants, and provide various helper methods. -# -# This weak coupling exists to facilitate graph rewrites and reduce the surface -# area for self-inconsistent representation. By readily discarding (deduced) -# information we don't need to carry invariants through complex passes; we can -# simply decouple the graph, perform whatever local modifications we like, and -# then reconnect everything. This representation is immutable (notwithstanding -# a few implementation details), so "decouple" means emitting a new erased -# graph. (Though simple value replacements can be done directly.) -JumpTarget = NewType("JumpTarget", tuple["ProtoBlock", parse.Jump]) -Uses = NewType("Uses", OrderedSet[parse.VariableKey]) - - -@dataclasses.dataclass(frozen=True, eq=False) -class ProtoBlock(instrumentation.InstrumentingBase): # type: ignore[misc,no-any-unimported] - """Stores abstract data flow for a code block.""" - - flow: values.IntraBlockFlow - jump_targets: tuple[JumpTarget, ...] = dataclasses.field(default=(), init=False) - uses: Uses = dataclasses.field(default_factory=lambda: Uses(OrderedSet()), init=False) - - def __repr__(self) -> str: - ops = "\n".join(f" {i.opname}" for i, _ in self.flow.symbolic) - return f"ProtoBlock: {hex(id(self))}\n{ops}" - - def __hash__(self) -> int: - return id(self) - - def __post_init__(self) -> None: - self.uses.update(self.flow.uses) - - def add_jump_target(self, other: ProtoBlock, jump: parse.Jump) -> None: - """We need to add jump targets after all ProtoBlocks are initialized.""" - - # Override `frozen=True` for this one limited use case. - object.__setattr__(self, "jump_targets", self.jump_targets + ((other, jump),)) - - -@dataclasses.dataclass(frozen=True, eq=False) -class ProtoGraph: - protoblocks: tuple[ProtoBlock, ...] - root: ProtoBlock - parents: Mapping[ProtoBlock, tuple[ProtoBlock, ...]] - - Provenance = values.ParsedSymbolic | tuple[type["ProtoGraphTransform"], "ProtoGraph"] - provenance: Provenance - - def __init__(self, protoblocks: Iterable[ProtoBlock], provenance: Provenance) -> None: - G = algorithms.TypedDiGraph[ProtoBlock]() - for protoblock in (protoblocks := tuple(protoblocks)): - is_return = tuple(protoblock.flow.symbolic)[-1][0].opname == parse.RETURN_VALUE - G.add_node(protoblock, is_return=is_return) - - for protoblock in protoblocks: - for destination, jump in protoblock.jump_targets: - G.add_edge(protoblock, destination, adjacent=not jump) - - assert protoblocks - object.__setattr__(self, "protoblocks", tuple(algorithms.sort_adjacent(G))) - assert len(G) == len(self.protoblocks) == len(protoblocks), (len(G), len(self.protoblocks), len(protoblocks)) - - object.__setattr__(self, "root", self.protoblocks[0]) - root_stack = [(k, v) for k, v in self.root.flow.begin_state if k.scope == parse.VariableScope.STACK] - assert not root_stack, f"Root block should not have stack inputs: {root_stack}" - - nodes = cast(Iterable[ProtoBlock], G.nodes) # For some reason mypy needs this. - parents = {protoblock: tuple(G.predecessors(protoblock)) for protoblock in nodes} - object.__setattr__(self, "parents", FrozenDict(parents)) - object.__setattr__(self, "provenance", provenance) - - @classmethod - def from_code(cls, co: CodeType) -> ProtoGraph: - """Given a method, disassemble it to a sequence of simple blocks.""" - parsed = values.ParsedSymbolic.make(parse.ParsedFunctional.make(co)) - protoblocks = tuple( - ProtoBlock(values.IntraBlockFlow(symbolic, begin, end)) for symbolic, begin, end in parsed.blocks - ) - for source, sink, jump in parsed.provenance.provenance.edges: - protoblocks[source].add_jump_target(protoblocks[sink], jump) - - return cls(protoblocks, parsed) - - def __iter__(self) -> Iterator[ProtoBlock]: - yield from self.protoblocks - - def __getitem__(self, index: int) -> ProtoBlock: - return self.protoblocks[index] - - def __len__(self) -> int: - return len(self.protoblocks) - - def __repr__(self) -> str: - return "\n\n".join(repr(protoblock) for protoblock in self) - - @property - def flat_flow(self) -> Iterable[tuple[parse.ThunderInstruction, values.Symbolic, values.Materialized]]: - for protoblock in self: - for instruction, symbolic in protoblock.flow.symbolic: - yield instruction, symbolic, protoblock.flow.materialized[instruction] - - @property - def is_linked(self) -> bool: - # NOTE: `is_linked` is vacuously True for a single block graph. - flat_begin = itertools.chain(*(i.flow._begin.values() for i in self if i is not self.root)) - return len(self) == 1 or any(not isinstance(i, values.AbstractRef) for i in flat_begin) - - def unlink(self) -> ProtoGraph: - return Unlink(self).apply(or_default=True) - - def link(self) -> ProtoGraph: - if result := ProtoGraphTransform.chain(self, AddTransitive, MatchStacks, Connect): - assert AddTransitive(result).apply(or_default=False) is None - assert (result or self).is_linked - return result or self - - def debug_print_protoflows(self) -> None: - """ - Print out the node_flow for each protoblock in the - protograph, in a way that's nice to read and debug with. - """ - - counter = 0 - idxes: dict[values.AbstractValue, int] = {} - for pb in self: - for node in pb.flow.materialized.values(): - for val in itertools.chain(node.inputs.ordered, node.outputs): - if val not in idxes.keys(): - idxes[val] = counter - counter += 1 - - def to_index_str(values: tuple[values.AbstractValue, ...]) -> str: - indices = (str(idxes[v]) for v in values) - return f"({', '.join(indices)})" - - for i, pb in enumerate(self): - print(f"Protoblock {i}:") - print(f"{'':>22}Inputs, Outputs") - for instruction, node in pb.flow.materialized.items(): - print(f" {instruction.opname:>20}, {to_index_str(node.inputs.ordered)} -> {to_index_str(node.outputs)}") - print("\n") - - -# ============================================================================= -# == Graph transforms (Base classes) ========================================== -# ============================================================================= -# ProtoGraphTransform -# ReplaceProtoBlocks -# ReplaceValues -# CondenseValues -# ReplaceSymbolic - - -class ProtoGraphTransform(abc.ABC): - """Handles mechanical portions of graph rewrites. - The base case is unopinionated; it simply accepts whatever new ProtoGraph is - emitted by `self._apply`. The primary feature it provides is checking. - - NOTE: - The convention adopted is for the pass logic to produce `T | None` - (e.g. `ProtoGraph | None`) where `None` signals that no change is - applicable. - - Forbid Linked: - A key invariant of ProtoGraph is that every AbstractValue has exactly - **one** producer, which is set by the symbolic flow. (With the exception - of `AbstractRef`s which are placeholders for an as-yet unspecified - AbstractValue.) However, within a block there is a flat list of - **concrete** values specifying the state at the start of the block. - - If one were to replace all instances of `X` in a ProtoGraph with `Y`, - this invariant would be preserved. On the other hand, if one were to - replace `X` with `Y` **only at the symbolic producer of `X`** then - downstream blocks could still have `X` as a block input, despite the - fact that `X` no longer has a producer. (Note that this is only a problem - across blocks; within blocks the materialization pass respects the update - and emits a consistent materialized state for the new ProtoBlock.) - - It is often convenient to simply rewrite the symbolic flow within a - single ProtoBlock. In that case the correct procedure is to generate an - unlinked ProtoGraph, perform the local rewrites, and then relink it. - (Where the connection pass will handle reconciliation automatically.) - - Check idempotence: - Nearly all passes are expected to be idempotent. This provides a good - deal of free test coverage since it produces both a test case (the result - of `self._apply`) and an expected result. (That `self._apply` returns - `None`.) We perform this check many times in order to flush out - non-deterministic passes. (Though the value is configurable if a pass is - particularly expensive.) - - However, given the potential added start up latency and possibility of - spurious failures this check is gated by `debug_asserts_enabled`, which - defaults to `False`. (Except for unit tests.) - """ - - _forbid_linked: bool = False - _kwargs: FrozenDict[str, Any] # Used to replay transform for `idempotent` check. - _idempotent_repeats: int = 10 - - @abc.abstractmethod - def _apply(self) -> ProtoGraph | None: - """Override this method to emit an (optional) new ProtoGraph.""" - ... - - def __new__(cls, *args: Any, **kwargs: Any) -> ProtoGraphTransform: - self = super().__new__(cls) - bound = inspect.signature(self.__class__.__init__).bind(None, *args, **kwargs).arguments - bound.pop("self") - bound.pop("proto_graph") - self._kwargs = FrozenDict(bound) - return self - - def __init__(self, proto_graph: ProtoGraph) -> None: - assert not (self._forbid_linked and len(proto_graph) > 1 and proto_graph.is_linked), self - assert isinstance(proto_graph, ProtoGraph) - self._protograph = proto_graph - - @property - def protograph(self) -> ProtoGraph: - return self._protograph - - @overload - def apply(self, or_default: Literal[False]) -> ProtoGraph | None: - ... - - @overload - def apply(self, or_default: Literal[True]) -> ProtoGraph: - ... - - def apply(self, or_default: bool = False) -> ProtoGraph | None: - result = self._apply() - if debug_asserts_enabled(): - result_to_check = result or self.protograph - for i in range(self._idempotent_repeats): - assert self.__class__(proto_graph=result_to_check, **self._kwargs)._apply() is None, (i, self) - return result or (self.protograph if or_default else None) - - @staticmethod - def chain(proto_graph: ProtoGraph, *transforms: type[ProtoGraphTransform]) -> ProtoGraph | None: - initial = proto_graph - for transform in transforms: - proto_graph = transform(proto_graph).apply(or_default=True) - return None if proto_graph is initial else proto_graph - - -class ReplaceProtoBlocks(ProtoGraphTransform): - """Helper to replace individual ProtoBlocks while retaining the same ProtoGraph topology.""" - - @abc.abstractmethod - def apply_to_protoblock(self, protoblock: ProtoBlock) -> values.IntraBlockFlow | None: - ... - - def post_apply(self, old: ProtoBlock, new: ProtoBlock) -> None: - pass - - def _apply(self) -> ProtoGraph | None: - # TODO(robieta): Right now block order is load bearing, so we have to preserve it. - transformed = {i: self.apply_to_protoblock(i) for i in self.protograph} - - if any(transformed.values()): - replacements = {i: ProtoBlock(flow or i.flow) for i, flow in transformed.items()} - for old_protoblock, new_protoblock in replacements.items(): - self.post_apply(old_protoblock, new_protoblock) - for old_target, is_jump in old_protoblock.jump_targets: - new_protoblock.add_jump_target(replacements[old_target], is_jump) - return ProtoGraph(replacements.values(), provenance=(self.__class__, self.protograph)) - return None - - -class ReplaceValues(ReplaceProtoBlocks): - """Copies the ProtoGraph with value replacements. - - NOTE: This is strictly a condensing transform, and this is only invertible - (using another `ReplaceValues`) in trivial cases. - """ - - _retain_uses: bool = True - - @abc.abstractproperty - def replace_map(self) -> values.ReplaceMap: - ... - - @functools.cached_property - def _replace_map(self) -> values.ReplaceMap: - replace_map = self.replace_map - assert not (invalid := [k for k in replace_map if isinstance(k, values.NonPyObject)]), invalid - return FrozenDict(algorithms.flatten_map(replace_map)) - - def apply_to_protoblock(self, protoblock: ProtoBlock) -> values.IntraBlockFlow | None: - return protoblock.flow.substitute(self._replace_map) - - def post_apply(self, old: ProtoBlock, new: ProtoBlock) -> None: - if self._retain_uses: - new.uses.update(old.uses) - - -class CondenseValues(ReplaceValues): - ValueEdges = Iterable[tuple[values.AbstractValue, values.AbstractValue]] - - @abc.abstractproperty - def edges(self) -> ValueEdges: - ... - - @property - def replace_map(self) -> values.ReplaceMap: - replace_map: dict[values.AbstractValue, values.AbstractValue] = {} - edges = itertools.chain(self.edges, self._phivalue_constituent_edges) - for v, condensed in algorithms.compute_condense_map(edges).items(): - # Check invariants. - assert condensed - if not isinstance(v, values.AbstractPhiValue): - invariants = ({c.identity for c in condensed} == {v.identity}, not isinstance(v, values.AbstractRef)) - assert all(invariants) or not any(invariants), (invariants, v, condensed) - - # `AbstractPhiValue._unpack_apply` will determine if we need an AbstractPhiValue. - if (replacement := values.substitute_value(values.AbstractPhiValue(tuple(condensed)), {})) != v: - replace_map[v] = replacement - - return FrozenDict(replace_map) - - @property - def _phivalue_constituent_edges(self) -> ValueEdges: - # AbstractPhiValues are somewhat unusual in that mismatches between blocks - # are expected (that's sort of the point...) so we need to decompose them - # so the condense pass doesn't get tripped up. - for _, initial_ref in self.protograph.root.flow.begin_state: - if isinstance(initial_ref, values.AbstractPhiValue): - yield from ((constituent, initial_ref) for constituent in initial_ref.constituents) - - -class ReplaceSymbolic(ReplaceProtoBlocks): - _forbid_linked = True - - @abc.abstractmethod - def apply_to_symbolic( - self, - instruction: parse.ThunderInstruction, - symbolic: values.Symbolic, - inputs: values.HybridMap[values.AbstractValue], - ) -> values.Symbolic | None: - ... - - def apply_to_protoblock(self, protoblock: ProtoBlock) -> values.IntraBlockFlow | None: - flow_state = values.DigestFlow(protoblock.flow._begin) - updated_symbolic: dict[parse.ThunderInstruction, values.Symbolic | None] = {} - for i, symbolic in protoblock.flow.symbolic: - updated_symbolic[i] = self.apply_to_symbolic(i, symbolic, symbolic.inputs.map(flow_state.get)) - _ = flow_state.next(i, updated_symbolic[i] or symbolic) - - if any(updated_symbolic.values()): - new_symbolic = {k: v or protoblock.flow._symbolic[k] for k, v in updated_symbolic.items()} - return dataclasses.replace(protoblock.flow, _symbolic=FrozenDict(new_symbolic)) - return None - - -# ============================================================================= -# == Graph transforms (Applied) =============================================== -# ============================================================================= -class Unlink(ReplaceProtoBlocks): - def apply_to_protoblock(self, protoblock: ProtoBlock) -> values.IntraBlockFlow | None: - if protoblock is not self.protograph.root: - uses = (flow := protoblock.flow).uses.copy() - end: values.Symbolic.EndState = FrozenDict({k: v for k, v in flow._end.items() if k != v}) - uses.update(v for v in end.values() if isinstance(v, parse.VariableKey) and not v.is_const) - any_non_ref = any(not isinstance(i, values.AbstractRef) for i in flow._begin.values()) - if any_non_ref or len(end) < len(flow._end) or flow._begin.keys() ^ uses: # symmetric_difference - begin: FrozenDict[parse.VariableKey, values.AbstractValue] - begin = FrozenDict({k: values.AbstractRef(f"Unlink: {k}") for k in uses}) - return dataclasses.replace(protoblock.flow, _begin=begin, _end=end) - - return None - - def _apply(self) -> ProtoGraph | None: - result = super()._apply() - assert len(result or self.protograph) == 1 or not (result or self.protograph).is_linked, result - return result - - -class AddTransitive(ReplaceProtoBlocks): - """Extend abstract value flows to include those needed by downstream blocks. - This pass effectively functionalizes the abstract value flow by plumbing - reads through parents as transitive dependencies. Note that we assume - variables are only modified by `STORE_...` and `DELETE_...` instructions. - This is not a sound assumption since opaque calls (`CALL_FUNCTION`, - `CALL_METHOD`, etc.) could mutate global and nonlocal variables. This does - not, however, pose an overall soundness problem because we can check for - state mutations during inlining and rerun flow analysis. - """ - - def apply_to_protoblock(self, protoblock: ProtoBlock) -> values.IntraBlockFlow | None: - flow = protoblock.flow - end = FrozenDict({**{use: use for use in self.target_uses(protoblock, self.expanded_uses)}, **flow._end}) - if (missing := self.expanded_uses[protoblock].difference(protoblock.uses)) or (end != flow._end): - begin = {**{k: values.AbstractRef("Transitive") for k in missing}, **flow._begin} - return dataclasses.replace(flow, _begin=FrozenDict(begin), _end=FrozenDict(end)) - return None - - def post_apply(self, old: ProtoBlock, new: ProtoBlock) -> None: - new.uses.update(self.expanded_uses[old]) - - @functools.cached_property - def expanded_uses(self) -> Mapping[ProtoBlock, Uses]: - """Identify new transitive value dependencies. - The process is more involved than simply checking for mismatches because - adding a transitive value to a block may necessitate adding a transitive - value to the prior block and so on. - """ - uses = {protoblock: protoblock.uses.copy() for protoblock in self.protograph} - blocks_to_process = collections.deque(uses.keys()) - - while blocks_to_process: - protoblock = blocks_to_process.popleft() - target_uses = self.target_uses(protoblock, uses) - - # The reason we can ignore ALL `_OutputRef`s (including those that would index into a composite) - # is that the (potential) composite's dependencies are already handled by `ProtoBlock._flow_uses`. - transitive_uses = OrderedSet( - source - for use in target_uses - if isinstance(source := protoblock.flow._end.get(use, use), parse.VariableKey) - and source.scope != parse.VariableScope.CONST - ) - - if transitive_uses - uses[protoblock]: - uses[protoblock].update(transitive_uses) - blocks_to_process.extend(self.protograph.parents[protoblock]) - - return FrozenDict(uses) - - @staticmethod - def target_uses(protoblock: ProtoBlock, uses: Mapping[ProtoBlock, Uses] = FrozenDict()) -> Uses: - flat_uses = itertools.chain(*(uses.get(target, target.uses) for target, _ in protoblock.jump_targets)) - return Uses(OrderedSet(use for use in flat_uses if use.scope != parse.VariableScope.CONST)) - - -class MatchStacks(ReplaceProtoBlocks): - """Ensure stacks match across blocks. - - ProtoGraph doesn't rely on stack behavior (push, pop TOS, etc.), however it - is still a good sanity check. (Which is why `Connect._inter_block_edges` asserts.) - """ - - def apply_to_protoblock(self, protoblock: ProtoBlock) -> values.IntraBlockFlow | None: - upstream = OrderedSet[parse.VariableKey]() - for parent in self.protograph.parents[protoblock]: - upstream.update(k for k in parent.flow._end if k.scope == parse.VariableScope.STACK) - - if delta := upstream - protoblock.flow._begin: - begin = {**protoblock.flow._begin, **{k: values.AbstractRef(f"Match stack: {k}") for k in delta}} - return dataclasses.replace(protoblock.flow, _begin=FrozenDict(begin)) - return None - - def post_apply(self, old: ProtoBlock, new: ProtoBlock) -> None: - new.uses.update(old.uses) - - -class Connect(CondenseValues): - @property - def edges(self) -> CondenseValues.ValueEdges: - yield from self._inter_block_edges(self.protograph) - yield from self._graph_input_edges(self.protograph) - - @staticmethod - def _graph_input_edges(proto_graph: ProtoGraph) -> CondenseValues.ValueEdges: - for key, initial_ref in proto_graph.root.flow.begin_state: - if isinstance(initial_ref.identity, values.ExternalRef): - continue - - assert isinstance(initial_ref, values.AbstractRef), initial_ref - assert key.scope not in ( - parse.VariableScope.CONST, - parse.VariableScope.STACK, - ), (key, proto_graph.root.flow._begin) - yield values.CompositeValue().add_identity(values.ExternalRef(key)), initial_ref - - @staticmethod - def _inter_block_edges(proto_graph: ProtoGraph) -> CondenseValues.ValueEdges: - for protoblock in proto_graph: - for child, _ in protoblock.jump_targets: - outputs = dict(protoblock.flow.end_state) - child_inputs = dict(child.flow.begin_state) - for key, child_input in child_inputs.items(): - yield outputs.get(key, values.NonPyObject(values.NonPyObject.Tag.MISSING)), child_input - - # `AddTransitive` should ensure the stacks match. - # (Except for return blocks which may discard the stack.) - opname = tuple(child.flow.symbolic)[-1][0].opname - if opname not in parse.RAISE_RETURN_INSTRUCTIONS: - s_out = tuple(sorted(i.identifier for i in outputs if i.scope == parse.VariableScope.STACK)) - s_in = tuple(sorted(i.identifier for i in child_inputs if i.scope == parse.VariableScope.STACK)) - assert s_out == s_in, f"{s_out=} != {s_in=}, {opname}" diff --git a/thunder/core/script/protograph_passes.py b/thunder/core/script/protograph_passes.py deleted file mode 100644 index 657d6d6c8f..0000000000 --- a/thunder/core/script/protograph_passes.py +++ /dev/null @@ -1,74 +0,0 @@ -import dataclasses -from collections.abc import Iterable - -from thunder.core.script import parse, values -from thunder.core.script.protograph import ProtoGraph, ProtoGraphTransform, AddTransitive, ReplaceSymbolic -from thunder.core.utils import debug_asserts_enabled - -ValueEdges = Iterable[tuple[values.AbstractValue, values.AbstractValue]] -KNOWN_TUPLE = values.TraitName("__known_tuple") - - -def _connect_protograph(proto_graph: "ProtoGraph") -> "ProtoGraph": - proto_graph = proto_graph.link() - assert AddTransitive(proto_graph).apply(or_default=False) is None - for protoblock in proto_graph: - for k, v in protoblock.flow.begin_state: - assert not v.is_detail, (k, v) - return proto_graph - - -class MarkTuples(ReplaceSymbolic): - def apply_to_symbolic( - self, - instruction: parse.ThunderInstruction, - symbolic: values.Symbolic, - _: values.HybridMap[values.AbstractValue], - ) -> values.Symbolic | None: - if instruction.opname == "BUILD_TUPLE" and isinstance(output := symbolic.outputs[0], values.IntermediateValue): - assert len(symbolic.outputs) == 1, symbolic.outputs - ordered = tuple(values.Reference(i) for i in range(-len(symbolic.inputs.ordered), 0)) - new_output = values.CompositeRef(ordered=ordered).add_named(KNOWN_TUPLE, values.ConstRef(True)) - return dataclasses.replace(symbolic, outputs=(new_output.add_identity(output),)) - return None - - -class IndexTuples(ReplaceSymbolic): - def apply_to_symbolic( - self, - instruction: parse.ThunderInstruction, - symbolic: values.Symbolic, - inputs: values.HybridMap[values.AbstractValue], - ) -> values.Symbolic | None: - replacement: values.Symbolic | None = None - if instruction.opname == "BINARY_SUBSCR": - to_index, index = inputs.ordered - is_tuple = isinstance(to_index, values.CompositeValue) and to_index.get(KNOWN_TUPLE) - index_key = index.key if isinstance(index, values.ExternalRef) and index.key.is_const else None - if is_tuple and index_key and isinstance(idx := index_key.identifier, int): - assert len(symbolic.outputs) == 1 - replacement = dataclasses.replace(symbolic, outputs=((values.Reference(0), values.Reference(idx)),)) - - elif instruction.opname == "UNPACK_SEQUENCE": - (to_unpack,) = inputs.ordered - if isinstance(to_unpack, values.CompositeValue) and to_unpack.get(KNOWN_TUPLE): - indices = (values.Reference(idx) for idx in range(-1, -len(symbolic.outputs) - 1, -1)) - outputs = tuple((values.Reference(0), idx) for idx in indices) - replacement = dataclasses.replace(symbolic, outputs=outputs) - - elif instruction.opname == "UNPACK_EX": - pass # TODO(apaz-cli): figure out indexing. - - return replacement if (replacement and replacement.outputs != symbolic.outputs) else None - - -def _tuple_fold(proto_graph: ProtoGraph) -> ProtoGraph: - """Replace tuple accesses (`BINARY_SUBSCR`, `UNPACK_SEQUENCE` instructions) with their members, if known.""" - return ProtoGraphTransform.chain(proto_graph, MarkTuples, IndexTuples) or proto_graph - - -def apply_protograph_passes(protograph: ProtoGraph) -> ProtoGraph: - protograph = _tuple_fold(protograph.unlink()) - protograph = _connect_protograph(protograph) - assert AddTransitive(protograph).apply(or_default=False) is None - return protograph diff --git a/thunder/core/script/python_ir.py b/thunder/core/script/python_ir.py deleted file mode 100644 index 2d66fbb77b..0000000000 --- a/thunder/core/script/python_ir.py +++ /dev/null @@ -1,507 +0,0 @@ -import collections -import dis -import inspect -import sys -import types -from typing import Any, Dict, List, Optional, Tuple, Union -from collections.abc import Callable -from collections.abc import Hashable - -from thunder.core.script.graph import ( - assert_block, - _generate_raises, - Graph, - GraphSummaryCallback, - MROAwareObjectRef, - Node, - SourceInformation, - Value, - insert_before, - insert_after, - _Undefined, -) -from thunder.core.script.parse import RETURN_VALUE -from thunder.core.script.python_ir_data import get_instruction, X_THUNDER_STORE_ATTR -from thunder.core.utils import OrderedSet - - -def undo_ssa(gr: "Graph") -> tuple[list[Value], list[str], list[str], list[Any]]: - consts: list[Any] = [] - names: list[str] = [] - - def get_value(v: Value, n: Node, inpidx: int | None = None) -> None: - if n.i.opname == "CALL_METHOD" and inpidx == 1: - bl = assert_block(n.block) - idx = bl.nodes.index(n) - if idx > 0 and bl.nodes[idx - 1].i.opname == "LOAD_METHOD": - # if we just a LOAD_METHOD, that did put input 0 and 1 on the stack - return - else: - # else the loading has been separated from the call, so we - # switch to call LOAD_ATTR/CALL_FUNCTION instead - n.i = n.i.modify_copy(opname="CALL_FUNCTION", opcode=None) - return - if isinstance(v.value, _Undefined): - idx = len(consts) - consts.append( - _generate_raises(f"attribute error '{type(v.value.value)}' object has no attribute '{v.value.attr}'") - ) - new_n = Node( - i=get_instruction(opname="LOAD_CONST", arg=idx), - outputs=[Value(value=consts[idx], is_const=True)], - inputs=[], - ) - new_n.inserted_for = n - insert_before(new_n, n) - new_n = Node(i=get_instruction(opname="CALL_FUNCTION", arg=0), outputs=[v], inputs=[consts[idx]]) - new_n.inserted_for = n - insert_before(new_n, n) - return - if v.is_const: - idx = len(consts) - consts.append(v.value) - new_n = Node(i=get_instruction(opname="LOAD_CONST", arg=idx), outputs=[v], inputs=[]) - new_n.inserted_for = n - insert_before(new_n, n) - elif isinstance(v.value, MROAwareObjectRef): - # this works for attribs, but for methods? maybe have a pass eliminating/making explicit the super... - get_value(v.value.obj, n) - elif v.parent is not None: - assert v.name is not None - get_value(v.parent, n) - if n.i.opname == "CALL_METHOD" and inpidx == 0: - # print("###inputs", n.inputs, v, v in n.inputs) - try: - idx = names.index(v.name) - except ValueError: - idx = len(names) - names.append(v.name) - new_n = Node( - i=get_instruction(opname="LOAD_METHOD", arg=idx), - outputs=[v, v.parent], - inputs=[v.parent], - ) - new_n.inserted_for = n - insert_before(new_n, n) - elif n.i.opname == "LOAD_ATTR": - # print("###load attr", n.outputs, n.i.argval) - pass - else: - assert v.name is not None - try: - idx = names.index(v.name) - except ValueError: - idx = len(names) - names.append(v.name) - new_n = Node( - i=get_instruction(opname="LOAD_ATTR", arg=idx), - outputs=[v], - inputs=[v.parent], - ) - new_n.inserted_for = n - insert_before(new_n, n) - elif v.is_global and isinstance(v.value, Hashable) and v.value in __builtins__: - # Builtins are unmarshallable and meant to be loaded globally. If they are - # included in co_consts, the resulting function cannot go into a .pyc file. - # Originally, the plan was to check if the value is a builtin by checking - # if its type is "". However, this - # turned out not to work since torch for some reason decided to set the - # type of `torch.nn.functional.has_torch_function` to also be a builtin. - if v.name not in names: - names.append(v.name) - idx = names.index(v.name) - new_n = Node(i=get_instruction(opname="LOAD_GLOBAL", arg=idx), outputs=[v], inputs=[]) - new_n.inserted_for = n - insert_before(new_n, n) - elif v.is_global: # make binding the globals optional? - if v.value not in consts: - consts.append(v.value) - idx = consts.index(v.value) - new_n = Node(i=get_instruction(opname="LOAD_CONST", arg=idx), outputs=[v], inputs=[]) - new_n.inserted_for = n - insert_before(new_n, n) - else: - idx = local_vars[v] - # assert idx >= 0 - new_n = Node(i=get_instruction(opname="LOAD_FAST", arg=idx), outputs=[v], inputs=[]) - new_n.inserted_for = n - insert_before(new_n, n) - - for bl in gr.blocks: - for n in bl.nodes: - n.block = bl - - local_vars: dict[Value, int] = {} - lv_names: OrderedSet[str] = OrderedSet() - - def get_or_add_lv(v: Value, name: str | None = None) -> int: - idx = local_vars.get(v) - if idx is None: - idx = len(local_vars) - local_vars[v] = idx - - # handle name collisions... - if name is None: - name = v.name - - if name is None: - name = f"_tmp_{idx}" - else: - name = name.replace(".", "_").replace("[", "").replace("]", "") - - if not name[:1].isalpha(): - name = "_" + name - fullname = name - suffix = 0 - while fullname in lv_names: - suffix += 1 - fullname = f"{name}_{suffix}" - lv_names.add(fullname) - if v.name is None: # TODO: or do this always? - v.name = fullname - return idx - - nodes_to_skip = set() - - def store_phi_values(o: Value, o_idx: int, last_n: Node | None, cur_n: Node | None) -> Node | None: - phi_values_in_processing = set() - - def store_phi_values_inner(o: Value, o_idx: int, last_n: Node | None) -> Node | None: - if o in phi_values_in_processing: - # avoid loops - return last_n - phi_values_in_processing.add(o) - for v in o.phi_values: - # TODO: refactor into general mechanism - idx2 = get_or_add_lv(v) - # last_n = store_phi_values_inner(v, o_idx, last_n) - if o.is_const: - if o.value not in consts: - consts.append(o.value) - o_idx = consts.index(o.value) - new_n = Node(i=get_instruction(opname="LOAD_CONST", arg=o_idx), outputs=[o], inputs=[]) - new_n.inserted_for = cur_n - else: - new_n = Node(i=get_instruction(opname="LOAD_FAST", arg=o_idx), outputs=[o], inputs=[]) - new_n.inserted_for = cur_n - - nodes_to_skip.add(new_n) - if last_n is None: - insert_before(new_n, gr.blocks[0].nodes[0]) - else: - insert_after(new_n, last_n) - last_n = new_n - new_n = Node(i=get_instruction(opname="STORE_FAST", arg=idx2), outputs=[], inputs=[o]) - new_n.inserted_for = cur_n - nodes_to_skip.add(new_n) - insert_after(new_n, last_n) - last_n = new_n - return last_n - - return store_phi_values_inner(o, o_idx, last_n) - - for v in gr.local_variables_at_start: - if v is not None: - get_or_add_lv(v) - - # inputs in phi values - last_n = None - # need to make a copy of the list because we're adding items to the list - for idx, i in enumerate(tuple(local_vars.keys())): - last_n = store_phi_values(i, idx, last_n, cur_n=None) - for i in gr.blocks[0].block_inputs: # inlined parameters (partial) will be here - for v, js in zip(i.values, i.jump_sources): - if js is None and v.is_const: - last_n = store_phi_values(v, None, last_n, cur_n=None) - # print(i.values, i.jump_sources) - - names = [] - - for bl in gr.blocks: - jump_node = bl.nodes[-1] - for n in bl.nodes[:]: - processed_block_outputs = set() - if n not in nodes_to_skip: - n.inserted_for = n - for inpidx, i in enumerate(n.inputs): - get_value(i, n=n, inpidx=inpidx) - last_n = n - for o in n.outputs[::-1]: - idx = get_or_add_lv(o) - new_n = Node( - i=get_instruction(opname="STORE_FAST", arg=idx), - outputs=[], - inputs=[o], - ) - new_n.inserted_for = n - assert last_n is not None - insert_after(new_n, last_n) - last_n = new_n - if o in bl.block_outputs: - processed_block_outputs.add(o) - last_n = store_phi_values(o, idx, last_n, cur_n=n) - if n.i.opname in ("STORE_ATTR", "IMPORT_NAME"): # STORE_ATTR for unknown objs - # have a utility for this? - try: - idx = names.index(n.i.argval) - except ValueError: - idx = len(names) - names.append(n.i.argval) - n.i = n.i.modify_copy(arg=idx) - if n.i.opname == X_THUNDER_STORE_ATTR: - bl.nodes.remove(n) - if bl.nodes[-1].i.opname != RETURN_VALUE: # TODO Should the return block have outputs (probably not) - for o in bl.block_outputs: - if o not in processed_block_outputs: - get_value(o, n=jump_node) # before the jump - idx = get_or_add_lv(o, name="bo") - new_n = Node( - i=get_instruction(opname="STORE_FAST", arg=idx), - outputs=[], - inputs=[o], - ) - new_n.inserted_for = jump_node - insert_before(new_n, n=jump_node) - store_phi_values(o, idx, new_n, cur_n=jump_node) - - return list(local_vars.keys()), list(lv_names), names, consts - - -# this function is taken from PyTorch Dynamo (c) 2022 by Facebook/Meta licensed -# as per https://github.com/pytorch/pytorch/blob/master/LICENSE -def linetable_writer(first_lineno: int) -> tuple[list[int], Callable, Callable]: - """Used to create typing.CodeType.co_linetable See - https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt This is the internal format of the line number - table if Python >= 3.10.""" - assert sys.version_info >= (3, 9) - linetable: list[int] = [] - lineno = first_lineno - lineno_delta = 0 - byteno = 0 - - def _update(byteno_delta: int, lineno_delta: int) -> None: - while byteno_delta != 0 or lineno_delta != 0: - byte_offset = max(0, min(byteno_delta, 254)) - line_offset = max(-127, min(lineno_delta, 127)) - assert byte_offset != 0 or line_offset != 0 - byteno_delta -= byte_offset - lineno_delta -= line_offset - linetable.extend((byte_offset, line_offset & 0xFF)) - - def update(lineno_new: int, byteno_new: int) -> None: - nonlocal lineno, lineno_delta, byteno - byteno_delta = byteno_new - byteno - byteno = byteno_new - _update(byteno_delta, lineno_delta) - lineno_delta = lineno_new - lineno - lineno = lineno_new - - def end(total_bytes: int) -> None: - _update(total_bytes - byteno, lineno_delta) - - return linetable, update, end - - -def generate_function(gr: "Graph") -> Callable: - orig_gr = gr - gr, map_from_orig = gr.clone() - - local_vars, lv_names, names, consts = undo_ssa(gr) - assert len(local_vars) == len(lv_names) - - NodeKey = Union[Node, tuple[Node, bool]] - instruction_sizes: dict[NodeKey, int] = {} - - def build_address_map(end=False) -> dict[NodeKey, int]: - # Key either (for jump nodes and jump=True) - # or (, False) for non-jump in conditional jump - address_map: dict[NodeKey, int] = {} - ctr = 0 - for bl in gr.blocks: - # assumes first block is function start - for n in bl.nodes: - address_map[n] = ctr - ctr += instruction_sizes.get(n, 1) - if len(n.jump_targets) == 2: # implicit unconditional jump - ctr += instruction_sizes.get((n, False), 1) - if end: - address_map[n] = ctr - 1 - return address_map - - def make_bc() -> tuple[list[int], bool]: - bc = [] - - def write_extended_args(node_key: NodeKey, arg: int) -> bool: - # returns if instruction size has changed - instruction_size = instruction_sizes.get(node_key, 1) - if arg > 0x_FF_FF_FF or instruction_size == 4: - instruction_size = 4 - bc.append(dis.opmap["EXTENDED_ARG"]) - bc.append(arg >> 24) - if arg > 0x_FF_FF or instruction_size >= 3: - instruction_size = max(instruction_size, 3) - bc.append(dis.opmap["EXTENDED_ARG"]) - bc.append((arg >> 16) & 0xFF) - if arg > 0x_FF or instruction_size >= 2: - instruction_size = max(instruction_size, 2) - bc.append(dis.opmap["EXTENDED_ARG"]) - bc.append((arg >> 8) & 0xFF) - else: - instruction_size = 1 - - if instruction_size != instruction_sizes.get(node_key, 1): - instruction_sizes[node_key] = instruction_size - return True - return False - - changed_size = False - line_no = None - for bl in gr.blocks: - jump_node = None - for n in bl.nodes: - opcode = n.i.opcode - if opcode is None or opcode == -1: # Todo: opcode is typed int in ThunderInstruction, remove None here? - opcode = dis.opmap[n.i.opname] - assert opcode is not None, f"{n} has invalid opcode" - # source range instead for 3.11? - n_line_no = n.source_infos[-1].gen_line_no if n.source_infos else None - if n_line_no is not None and n_line_no != line_no: # really, the last generated one... - linetable_update( - n_line_no + gr.source_start_line, address_map[n] * 2 - ) # byte offset for Python 3.10, too... - line_no = n_line_no - if opcode in dis.hasjabs: - arg = address_map[n.jump_targets[-1].nodes[0]] - elif opcode in dis.hasjrel: - # TODO forward, backward - arg = address_map[n.jump_targets[-1].nodes[0]] - address_map[n] - 1 - else: - arg_ = n.i.arg - arg = 0 if arg_ is None else arg_ - - changed_size |= write_extended_args(n, arg) - - bc.append(opcode) - bc.append(arg & 0x_FF) - if len(n.jump_targets) > 1: - jump_node = n - if jump_node is not None: - assert len(jump_node.jump_targets) == 2 - jarg = address_map[jump_node.jump_targets[0].nodes[0]] - changed_size |= write_extended_args((jump_node, False), jarg) - i = get_instruction(opname="JUMP_ABSOLUTE", arg=jarg & 0xFF) - bc.append(i.opcode) - assert i.arg is not None - bc.append(i.arg) - return bc, not changed_size - - done = False - while not done: - linetable, linetable_update, linetable_end = linetable_writer(gr.source_start_line) - address_map = build_address_map() - bc, done = make_bc() - - inserted_for = collections.defaultdict(list) - end_address_map = build_address_map(end=True) - for n in gr.nodes(): - inserted_for[getattr(n, "inserted_for", None)].append(end_address_map[n]) - for n in orig_gr.nodes(): - info = inserted_for[map_from_orig[n]] - n.bytecode_range = (min(info), max(info)) if info else (None, None) - - linetable_end(len(bc)) - linetable_bytes = bytes(linetable) - bc_bytes = bytes(bc) - - lv_at_start = [v for v in gr.local_variables_at_start if v is not None] - co_argcount = gr.co_argcount - co_posonlyargcount = gr.co_posonlyargcount - co_kwonlyargcount = gr.co_kwonlyargcount - co_nlocals = len(local_vars) - # TODO: actually track the stack size when doing codegen (for optimizations) - co_stacksize = max(max(len(n.inputs), len(n.outputs)) for n in gr.nodes()) - co_flags = gr.co_flags - co_codestring = bc_bytes - co_consts = tuple(consts) - co_names = tuple(names) - co_varnames = tuple(lv_names) - co_filename = f"" - co_name = gr.co_name - co_firstlineno = gr.source_start_line - co_linetable = linetable_bytes # XXX - co_freevars = () - co_cellvars = () - - c = types.CodeType( - co_argcount, # int - co_posonlyargcount, # int - co_kwonlyargcount, # int - co_nlocals, # int - co_stacksize, # int - co_flags, # int - co_codestring, # bytes - co_consts, # tuple - co_names, # tuple - co_varnames, # tuple - co_filename, # string - co_name, # string - co_firstlineno, # integer - co_linetable, # bytes - co_freevars, # tuple - co_cellvars, # tuple - ) - - # types.FunctionType(code, globals, name=None, argdefs=None, closure=None) - func = types.FunctionType( - c, - { - "__builtins__": __builtins__, - }, - argdefs=tuple(gr.func_defaults), - ) - func.__kwdefaults__ = gr.func_kwdefaults - func._gr = orig_gr - - # simple cache hack - mtime = None # this signals that the cache should not be invalidated(!) - lines = gr.source_lines - size = len("".join(lines)) - inspect.linecache.cache[co_filename] = size, mtime, lines, co_filename - - try: - _ = tuple(dis.get_instructions(func)) - except BaseException as e: - raise RuntimeError("Unknown error generating callable") from e - - return func - - -def annotated_dis(thunder_fn, print_lines=True): - instructions = list(dis.get_instructions(thunder_fn)) - cur_pos = 0 - - class Callback(GraphSummaryCallback): - def node(self, n): - nonlocal cur_pos - before = [] - after = [] - begin_offset, end_offset = n.bytecode_range - if begin_offset is not None: - # the * 2 here and below is from Instruction.offset containing byte offsets and each - # bytecode is 2 bytes (this is true at least for Python 3.8-3.12) - while cur_pos < len(instructions) and instructions[cur_pos].offset < begin_offset * 2: - before.append(instructions[cur_pos]._disassemble()) - cur_pos += 1 - if end_offset is not None: - while cur_pos < len(instructions) and instructions[cur_pos].offset <= end_offset * 2: - after.append(instructions[cur_pos]._disassemble()) - cur_pos += 1 - return before, after - - def finish(self): - nonlocal cur_pos - l = [i._disassemble() for i in instructions[cur_pos:]] - cur_pos = len(instructions) - return l - - return thunder_fn._gr.summary(print_lines=print_lines, callback=Callback()) diff --git a/thunder/core/script/python_ir_data.py b/thunder/core/script/python_ir_data.py deleted file mode 100644 index c82eb168aa..0000000000 --- a/thunder/core/script/python_ir_data.py +++ /dev/null @@ -1,68 +0,0 @@ -import functools -import sys -from types import CodeType -from typing import Union -from collections.abc import Callable -from collections.abc import Iterable - -from thunder.core.script import parse - - -SUPPORTS_PREPROCESSING = (3, 9) <= sys.version_info < (3, 11) -X_THUNDER_STORE_ATTR = "X_THUNDER_STORE_ATTR" - - -# TODO(robieta): replace callsites. -get_instruction = functools.partial(parse.ThunderInstruction.make, line_no=-1) - - -def debug_compare_functions_print(diffs: dict[str, tuple[list, list]]): - for k, (v1, v2) in diffs.items(): - if not (v1 is None and v2 is None): - print(f"Differences in: {k}") - print(f" CodeObject 1: {v1}") - print(f" CodeObject 2: {v2}") - - -def debug_compare_functions( - code1: CodeType | Callable, code2: CodeType | Callable, *, show=False -) -> dict[str, tuple[list, list]]: - if not isinstance(code1, CodeType): - code1 = code1.__code__ - if not isinstance(code2, CodeType): - code2 = code2.__code__ - - attrs = [ - "co_argcount", - "co_kwonlyargcount", - "co_nlocals", - "co_stacksize", - "co_flags", - "co_consts", - "co_names", - "co_varnames", - "co_filename", - "co_name", - "co_freevars", - "co_cellvars", - ] - - diffs = {} - for attr in attrs: - v1 = getattr(code1, attr) - v2 = getattr(code2, attr) - - if v1 != v2: - if isinstance(v1, dict) and isinstance(v2, dict): - diffs[attr] = (v1 - v2, v2 - v1) - if isinstance(v1, str) and isinstance(v2, str): - diffs[attr] = (v1, v2) - elif isinstance(v1, Iterable) and isinstance(v2, Iterable): - diffs[attr] = (set(v1) - set(v2), set(v2) - set(v1)) - else: - diffs[attr] = (v1, v2) - - if show: - debug_compare_functions_print(diffs) - - return diffs diff --git a/thunder/core/script/values/__init__.py b/thunder/core/script/values/__init__.py deleted file mode 100644 index f1361cd857..0000000000 --- a/thunder/core/script/values/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from thunder.core.script.values.base import * -from thunder.core.script.values.composite import * -from thunder.core.script.values.materialization import * -from thunder.core.script.values.symbolic import * diff --git a/thunder/core/script/values/base.py b/thunder/core/script/values/base.py deleted file mode 100644 index 1956c31050..0000000000 --- a/thunder/core/script/values/base.py +++ /dev/null @@ -1,200 +0,0 @@ -from __future__ import annotations - -import dataclasses -import enum -import textwrap -from typing import overload, Any, Generic, NewType, TypeVar -from collections.abc import Callable -from collections.abc import Mapping - -from typing_extensions import Self - -from thunder.core.script import parse -from thunder.core.utils import FrozenDict - -__all__ = ( - # HybridMap - "Reference", - "TraitName", - "HybridMap", - # - # Values - "AbstractValue", - "AbstractRef", - "NonPyObject", - "IntermediateValue", - "ExternalRef", - # - # Substitution - "substitute_value", - "ReplaceMap", -) - - -# ============================================================================= -# == Generic (hybrid tuple/dict) container ==================================== -# ============================================================================= -T = TypeVar("T") -T1 = TypeVar("T1") -Reference = NewType("Reference", int) -TraitName = NewType("TraitName", str) - - -@dataclasses.dataclass(frozen=True, eq=True) -class HybridMap(Generic[T]): - ordered: tuple[T, ...] = dataclasses.field(kw_only=True, default_factory=tuple) - named: FrozenDict[TraitName, T] = dataclasses.field(kw_only=True, default_factory=FrozenDict) - - def __getitem__(self, key: Reference | TraitName) -> T: - if isinstance(key, int): - return self.ordered[key] - elif isinstance(key, str): - return self.named[key] - raise TypeError(f"Invalid key: {key}") - - def __repr__(self) -> str: - parts = [f"{self.__class__.__name__}("] - if self.ordered: - ordered = "\n".join(repr(i) for i in self.ordered) - parts.append(f" ordered:\n{textwrap.indent(ordered, ' ' * 4)}") - - if self.named: - named = "\n".join(f"{k}: {v}" for k, v in self.named.items()) - parts.append(f" named:\n{textwrap.indent(named, ' ' * 4)}") - - return "\n".join((*parts, ")")) - - @overload - def map(self, f: Callable[[T], T]) -> Self: - ... - - @overload - def map(self, f: Callable[[T], T1]) -> HybridMap[T1]: - ... - - def map(self, f: Any) -> Any: - ordered = tuple(f(i) for i in self.ordered) - named: FrozenDict[TraitName, T] = FrozenDict({k: f(v) for k, v in self.named.items()}) - return dataclasses.replace(self, ordered=ordered, named=named) - - def get(self, name: TraitName) -> T | None: - return self.named.get(name) - - def add_named(self, name: TraitName, value: T) -> Self: - named = dict(self.named) - named.update({name: value}) # Preserve order. - return dataclasses.replace(self, named=FrozenDict(named)) - - -# ============================================================================= -# == Simple value types ======================================================= -# ============================================================================= -class AbstractValue: - """Represents a value during instruction parsing. (Prior to type binding.)""" - - __is_detail = True - - def __init_subclass__(cls, **kwargs: Any) -> None: - cls.__is_detail = kwargs.pop("__is_detail", False) # `dataclasses` forces this into kwargs for some reason. - super().__init_subclass__(**kwargs) - - def __copy__(self) -> AbstractValue: - raise NotImplementedError - - @property - def is_detail(self) -> bool: - return self.__is_detail - - @property - def identity(self) -> AbstractValue: - """Analogous to `id(obj)`. For composites there is a layer of state management above the value itself. - - This is not suitable for equality checks (for example, mutation does not change - an object's identity), but it is often the appropriate target for `isinstance` checks. - """ - return self - - def _unpack_apply(self, _: ReplaceMap) -> AbstractValue: - """Recursively update any constituent references in the abstract value.""" - return self - - -@dataclasses.dataclass(frozen=True, eq=False) -class AbstractRef(AbstractValue, __is_detail=True): - """Placeholder value which will be resolved during parsing.""" - - _debug_info: str = "N/A" - - -@dataclasses.dataclass(frozen=True, eq=True) -class NonPyObject(AbstractValue): - """Singleton values used to signal some special interpreter state.""" - - class Tag(enum.Enum): - DELETED = enum.auto() - MISSING = enum.auto() - NULL = enum.auto() - - tag: Tag - - def __repr__(self) -> str: - return self.tag.name - - -class IntermediateValue(AbstractValue): - """A (potentially) new value produced by an instruction.""" - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(at {hex(id(self))})" - - -@dataclasses.dataclass(frozen=True, eq=True) -class ExternalRef(AbstractValue): - """Reference values outside of the parsed code. (Arguments, constants, globals, etc.)""" - - key: parse.VariableKey - - def __repr__(self) -> str: - if self.key.is_const: - return f"Const({self.key.identifier})" - return f"{self.__class__.__name__}({self.key.identifier}, {self.key.scope.name})" - - -# ============================================================================= -# == Value substitution ======================================================= -# ============================================================================= -ReplaceMap = Mapping["AbstractValue", "AbstractValue"] - - -@overload -def substitute_value(v: AbstractValue, replace_map: ReplaceMap) -> AbstractValue: - ... - - -@overload -def substitute_value(v: T, replace_map: ReplaceMap) -> T: - ... - - -def substitute_value(v: Any, replace_map: ReplaceMap) -> Any: - """Find the replacement for `v`, and recursively substitute. (If applicable.) - - Some abstract values reference other abstract values. When we make substitution during - graph transformations it is necessary to also consider replacement of an abstract - value's constituents. Any subclass which must be unpacked in this manner should - override `_unpack_apply`. - """ - if not isinstance(v, AbstractValue): - return v - - new_v = replace_map.get(v, v) - if new_v != (x := replace_map.get(new_v, new_v)): - msg = f""" - `replace_map` may not contain chains. - {v} - {new_v} - {x} - See `flatten_map`.""" - raise ValueError(textwrap.dedent(msg)) - - return new_v._unpack_apply(replace_map) diff --git a/thunder/core/script/values/composite.py b/thunder/core/script/values/composite.py deleted file mode 100644 index bc00dd5469..0000000000 --- a/thunder/core/script/values/composite.py +++ /dev/null @@ -1,177 +0,0 @@ -import abc -import dataclasses -import itertools -from typing import Any, TypeVar -from collections.abc import Iterable - -from typing_extensions import Self - -from thunder.core.script import parse -from thunder.core.script.values import base, symbolic -from thunder.core.utils import FrozenDict - -__all__ = ("InternalRef", "OrderedSlice", "CompositeValue", "CompositeRef", "AbstractPhiValue") - -T = TypeVar("T") - - -# ============================================================================= -# == References =============================================================== -# ============================================================================= -class InternalRef(base.AbstractValue, abc.ABC): - @abc.abstractmethod - def _resolve(self, inputs: base.HybridMap[base.AbstractValue]) -> base.AbstractValue: - """Defines how to concretize itself.""" - ... - - @property - def is_detail(self) -> bool: - # All ref types are unsuitable for Graph binding. - return True - - @classmethod - def resolve( - cls, output: symbolic.Symbolic.Output, *, inputs: base.HybridMap[base.AbstractValue] | None = None - ) -> base.AbstractValue: - inputs = base.HybridMap() if inputs is None else inputs - if isinstance(output, (int, str)): - return inputs[output] - - elif isinstance(output, symbolic.ConstRef): - return base.ExternalRef(parse.VariableKey(output.identifier, parse.VariableScope.CONST)) - - if isinstance(output, tuple): - cls.validate_reference(output) - result = inputs[output[0]] - for idx in output[1:]: - # We can only unpack a (possibly nested) composite. - assert isinstance(result, base.HybridMap), result - result = result[idx] - - assert isinstance(result, base.AbstractValue) - return result - - elif isinstance(output, InternalRef): - return output._resolve(inputs) - - return output - - @staticmethod - def validate_reference(x: symbolic.NestedReference) -> None: - x = (x,) if isinstance(x, (int, str)) else x - assert isinstance(x, tuple) and x and all(isinstance(xi, (int, str)) for xi in x), x - - -# ============================================================================= -# == Nesting ================================================================== -# ============================================================================= -@dataclasses.dataclass(frozen=True, eq=True) -class OrderedSlice: - reference: symbolic.NestedReference - slice: slice - - def __hash__(self) -> int: - # `slice` isn't hashable until 3.12 - return hash((self.reference, self.slice.start, self.slice.stop, self.slice.step)) - - -@dataclasses.dataclass(frozen=True, eq=True, repr=False) -class _Composite(base.AbstractValue, base.HybridMap[T], __is_detail=True): - """Models an AbstractValue that references other (possibly also AbstractValue) state. - - Note: `ordered` and `named` should not contain cycles. - """ - - Identity = base.TraitName("__Thunder_Object_Identity") - - def _unpack_apply(self, replace_map: base.ReplaceMap) -> base.AbstractValue: - new_self = self.map(lambda x: base.substitute_value(x, replace_map)) - assert isinstance(new_self, _Composite) # For mypy since we can't hint `_Composite[T] -> _Composite[T1]` - return new_self - - def add_identity(self, identity: T) -> Self: - return self.add_named(self.Identity, identity) - - # NOTE: We don't override `identity`. (Instead retaining `return self` from AbstractValue.) - # This is because we generally won't know a good value, and passes should do their - # own type checking. (And that checking should almost always be done on the materialized - # value, not the symbolic reference.) - - -@dataclasses.dataclass(frozen=True, eq=True, repr=False) -class CompositeValue(_Composite[base.AbstractValue]): - def __post_init__(self) -> None: - assert all(isinstance(i, base.AbstractValue) for i in self.ordered) - assert all(isinstance(i, base.AbstractValue) for i in self.named.values()) - - @property - def is_detail(self) -> bool: - return any(i.is_detail for i in itertools.chain(self.ordered, self.named.values())) - - @property - def identity(self) -> base.AbstractValue: - return self.named.get(self.Identity, self) - - -@dataclasses.dataclass(frozen=True, eq=True) -class CompositeRef(InternalRef, _Composite[symbolic.Symbolic.Output | OrderedSlice]): - def __post_init__(self) -> None: - assert not any(isinstance(i, OrderedSlice) for i in self.named.values()) - - def _resolve(self, inputs: base.HybridMap[base.AbstractValue]) -> CompositeValue: - ordered: list[base.AbstractValue] = [] - for i in self.ordered: - if isinstance(i, OrderedSlice): - slice_target = self.resolve(i.reference, inputs=inputs) if i.reference else inputs - assert isinstance(slice_target, base.HybridMap) - ordered.extend(slice_target.ordered[i.slice]) - else: - ordered.append(self.resolve(i, inputs=inputs)) - - named: dict[base.TraitName, base.AbstractValue] = {} - for k, v in self.named.items(): - assert not isinstance(v, OrderedSlice) - named[k] = self.resolve(v, inputs=inputs) - - return CompositeValue(ordered=tuple(ordered), named=FrozenDict(named)) - - -# ============================================================================= -# == Unions =================================================================== -# ============================================================================= -@dataclasses.dataclass(frozen=True, eq=True) -class AbstractPhiValue(base.AbstractValue): - constituents: tuple[base.AbstractValue, ...] - - def __post_init__(self) -> None: - # Flatten nested PhiValues. e.g. - # 𝜙[𝜙[A, B], 𝜙[A, C]] -> 𝜙[A, B, C] - constituents = itertools.chain(*[self.flatten(i) for i in self.constituents]) - - # Ensure a consistent order. - constituents = tuple(v for _, v in sorted({hash(v): v for v in constituents}.items())) - assert not any(isinstance(i, InternalRef) for i in constituents) - object.__setattr__(self, "constituents", constituents) - - def __getitem__(self, _: Any) -> base.AbstractValue: - # The semantics of indexing into an `AbstractPhiValue`` are not well defined: - # - The order of `constituents` is arbitrary - # - It's unclear if the desire is to select one constituent or create a new `AbstractPhiValue` - # which indexes into each constituent. - # If a concrete use case emerges we can tackle it; until then we refuse for safety. - - # TODO(robieta): Handle traits - raise NotImplementedError - - def _unpack_apply(self, replace_map: base.ReplaceMap) -> base.AbstractValue: - result = AbstractPhiValue(tuple(base.substitute_value(v, replace_map) for v in self.constituents)) - return result if len(result.constituents) > 1 else result.constituents[0] - - @classmethod - def flatten(cls, v: base.AbstractValue) -> Iterable[base.AbstractValue]: - constituents = [cls.flatten(i) for i in v.constituents] if isinstance(v, AbstractPhiValue) else [[v]] - yield from itertools.chain(*constituents) - - @property - def is_detail(self) -> bool: - return any(i.is_detail for i in self.constituents) diff --git a/thunder/core/script/values/materialization.py b/thunder/core/script/values/materialization.py deleted file mode 100644 index e12e35d0e5..0000000000 --- a/thunder/core/script/values/materialization.py +++ /dev/null @@ -1,171 +0,0 @@ -from __future__ import annotations - -import dataclasses -import functools -import itertools -from types import MappingProxyType -from typing import Literal, TypeVar -from collections.abc import Callable, Iterator - -from thunder.core.script import parse -from thunder.core.script.values import base, composite, symbolic -from thunder.core.utils import FrozenDict, OrderedSet -from collections.abc import Iterable - -__all__ = ("Materialized", "DigestFlow", "IntraBlockFlow") -T = TypeVar("T") - - -# ============================================================================= -# == Intra-ProtoBlock abstract value flow ===================================== -# ============================================================================= -# -# `ProtoBlocks` employ a dual representation, where node inputs and outputs can -# be viewed as either a reference based DAG or a sequence of ops with concrete -# `AbstractValue` inputs and outputs. -# -# At the boundaries of a ProtoBlock values have named (VariableKey) slots; -# within the ProtoBlock there is no need for such slots (since there is no -# control flow within a block and those named slots tell you how to build the -# directed *cyclic* graph for the larger program) so they are stripped during -# parsing. -# -# The inputs of a protoblock are stored as a map of `VariableKey -> AbstractValue` -# and act as the intra-block DAG sources. The outputs are stored as references -# since every ProtoBlock output must have a unique producer. (Either an input -# or a node within the block.) -# -# The canonical representation for intra-block flow is "symbolic" (reference -# based). If an `AbstractValue` appear in a symbolic node's outputs that -# indicates that the node is that value's producer. Otherwise all inputs and -# outputs are references: inputs reference either the begin state or the -# outputs of a prior node while output references index into the node's inputs. -# -# When analyzing a graph we are generally interested in the concrete properties -# of values; provenance is generally only important when connecting blocks and -# performing rewrites. For these cases `IntraBlockFlow` generates a -# "materialized" flow which resolves all references to `AbstractValue`s. The -# symbolic representation is sufficient to emit the materialized representation, -# but the reverse is not true. -VarT = TypeVar("VarT", bound=base.AbstractValue, covariant=True) -ConcreteState = FrozenDict[parse.VariableKey, base.AbstractValue] -EndState = FrozenDict[parse.VariableKey, symbolic.Symbolic.Input] - - -@dataclasses.dataclass(frozen=True, eq=False) -class Materialized: - """Flow element where all symbolic references have been resolved to concrete `AbstractValue`s.""" - - inputs: base.HybridMap[base.AbstractValue] - outputs: tuple[base.AbstractValue, ...] - - -class DigestFlow: - """One-shot helper for materializing a block.""" - - GetT = Callable[[symbolic.Symbolic.Input], base.AbstractValue] - - def __init__(self, begin: ConcreteState) -> None: - self._begin = begin - self._result: dict[parse.ThunderInstruction, Materialized] = {} - - def next(self, instruction: parse.ThunderInstruction, symbolic: symbolic.Symbolic) -> Materialized: - """Lookup the materialized node corresponding to a symbolic node.""" - - # NB: `inputs_after_op` will be needed after we introduce mutations. - inputs = inputs_after_op = symbolic.inputs.map(self.get) - outputs = tuple(composite.InternalRef.resolve(o, inputs=inputs_after_op) for o in symbolic.outputs) - assert all(isinstance(o, base.AbstractValue) for o in outputs), outputs - - self._result[instruction] = result = Materialized(inputs, outputs) - return result - - def get(self, key: symbolic.Symbolic.Input) -> base.AbstractValue: - """Resolve a Symbolic input based on the block state at that node.""" - result: base.AbstractValue - if isinstance(key, base.NonPyObject.Tag): - result = base.NonPyObject(key) - - elif isinstance(key, symbolic.OutputRef): - inputs = base.HybridMap(ordered=self._result[key.instruction].outputs) - result = composite.InternalRef.resolve(key.idx, inputs=inputs) - - else: - assert isinstance(key, parse.VariableKey), key - result = base.ExternalRef(key) if key.is_const else self._begin[key] - return result - - -@dataclasses.dataclass(frozen=True, eq=False) -class IntraBlockFlow: - _symbolic: FrozenDict[parse.ThunderInstruction, symbolic.Symbolic] - _begin: ConcreteState - _end: EndState - - StateIterT = Iterator[tuple[parse.VariableKey, base.AbstractValue]] - - def __post_init__(self) -> None: - assert not (forbidden := tuple(i for i in self._symbolic if i in parse.FORBIDDEN_INSTRUCTIONS)), forbidden - object.__setattr__(self, "_symbolic", FrozenDict(self._symbolic)) - - missing = {i: base.AbstractRef("Inferred") for i in self.uses if i not in self._begin} - object.__setattr__(self, "_begin", FrozenDict({**missing, **self._begin})) - assert not any(k.is_const for k in self._begin), self._begin - assert not any(isinstance(v, composite.InternalRef) for v in self._begin.values()), self._begin - - object.__setattr__(self, "_end", FrozenDict(self._end)) - - @functools.cache - def __getitem__(self, key: tuple[symbolic.Symbolic.Input, Literal[0, 1]]) -> base.AbstractValue: - assert key[1] in (0, 1) - return self._computed[1:][key[1]](key[0]) - - @property - def symbolic(self) -> Iterable[tuple[parse.ThunderInstruction, symbolic.Symbolic]]: - yield from self._symbolic.items() - - @property - def materialized(self) -> FrozenDict[parse.ThunderInstruction, Materialized]: - return self._computed[0] - - @property - def begin_state(self) -> StateIterT: - yield from self._sort_and_filter_state(iter(self._begin.items())) - - @property - def end_state(self) -> StateIterT: - yield from self._sort_and_filter_state((k, self[v, 1]) for k, v in self._end.items()) - - @staticmethod - def _sort_and_filter_state(kv: StateIterT) -> StateIterT: - yield from ((k, v) for k, v in sorted(kv) if not isinstance(v, base.NonPyObject)) - - @property - def uses(self) -> OrderedSet[parse.VariableKey]: - assignment = (v for k, v in self._end.items() if isinstance(v, parse.VariableKey) and not v.is_const and k != v) - return OrderedSet(itertools.chain(*(s.uses for _, s in self.symbolic), assignment)) - - _Computed = tuple[ - FrozenDict[parse.ThunderInstruction, Materialized], - DigestFlow.GetT, # Begin - DigestFlow.GetT, # End - ] - - @functools.cached_property - def _computed(self) -> _Computed: - flow_state = DigestFlow(self._begin) - materialized_flow: FrozenDict[parse.ThunderInstruction, Materialized] - materialized_flow = FrozenDict({i: flow_state.next(i, s) for i, s in self.symbolic}) # Populates `flow_state` - return materialized_flow, DigestFlow(self._begin).get, flow_state.get - - def substitute(self, replace_map: base.ReplaceMap) -> IntraBlockFlow | None: - """Replace `AbstractValue`s within the flow. (Block inputs and producer nodes.)""" - replace_map_view = MappingProxyType(replace_map) - new_symbolic: FrozenDict[parse.ThunderInstruction, symbolic.Symbolic] - new_symbolic = FrozenDict({k: (s.substitute(replace_map_view)) for k, s in self.symbolic}) - begin = ConcreteState({k: base.substitute_value(v, replace_map_view) for k, v in self._begin.items()}) - - # TODO(robieta): Check if a value is only present in `materialized` and error. - if self._symbolic != new_symbolic or self._begin != begin: - return dataclasses.replace(self, _symbolic=new_symbolic, _begin=begin) - return None diff --git a/thunder/core/script/values/symbolic.py b/thunder/core/script/values/symbolic.py deleted file mode 100644 index 60bbaa08b6..0000000000 --- a/thunder/core/script/values/symbolic.py +++ /dev/null @@ -1,175 +0,0 @@ -"""Introduce references inside simple blocks.""" -from __future__ import annotations - -import dataclasses -import itertools -import sys -from typing import Any, NamedTuple, TypeAlias -from collections.abc import Callable, Iterable - -from typing_extensions import Self - -from thunder.core.script import parse -from thunder.core.script.values import base -from thunder.core.utils import FrozenDict, safe_zip - -__all__ = ("OutputRef", "ParsedSymbolic", "Symbolic", "NestedReference", "ConstRef") - - -# ============================================================================= -# == Opcode-specific behavior ================================================= -# ============================================================================= -def rotate_N(oparg: int) -> tuple[int, ...]: - return (-1,) + tuple(range(-oparg, -1)) - - -_AliasMask = tuple[int | None, ...] -ALIAS_OPCODES = FrozenDict[str, _AliasMask | Callable[[int], _AliasMask]]( - parse.fill_ellipses( - # - # Stack manipulation - ROT_N=rotate_N, # A,B,...,Z -> Z,A,B,... - ROT_FOUR=rotate_N, - ROT_THREE=rotate_N, - ROT_TWO=rotate_N, - DUP_TOP=(-1, -1), # A -> A,A - DUP_TOP_TWO=(-2, -1) * 2, # A,B -> A,B,A,B - # - # Insertion leaves container on the stack A,B -> A - SET_ADD=(-2,), - SET_UPDATE=..., - LIST_APPEND=..., - LIST_EXTEND=..., - DICT_MERGE=..., - DICT_UPDATE=..., - MAP_ADD=(-3,), - COPY_DICT_WITHOUT_KEYS=(-2, None), # A,B -> A,C (I am unsure...) - # - # Misc. - GET_LEN=(-1, None), - MATCH_MAPPING=(-1, None), - MATCH_SEQUENCE=..., - MATCH_KEYS=(-1, -2, None) + () if sys.version_info >= (3, 11) else (None,), - # - # Jump dependent - FOR_ITER=(-1, None), - # NOTE: These instructions have been removed since they are extraneous special cases. - # https://github.com/faster-cpython/ideas/issues/567 - # https://github.com/python/cpython/issues/102859 - JUMP_IF_TRUE_OR_POP=(-1,), - JUMP_IF_FALSE_OR_POP=(-1,), - # - # This isn't actually correct. `LOAD_METHOD` will return either - # A -> B, A - # A -> B, NULL - # However the `A | NULL` is only consumed by `CALL_METHOD`, so it's ok to use this alias. - LOAD_METHOD=(None, -1), # A -> B,A - ) -) - - -# ============================================================================= -# == Symbolic flow ============================================================ -# ============================================================================= -IndexT: TypeAlias = base.Reference | base.TraitName -NestedReference: TypeAlias = IndexT | tuple[IndexT, ...] - - -@dataclasses.dataclass(frozen=True, eq=True) -class OutputRef: - """Identifies the producer of a value within a block.""" - - instruction: parse.ThunderInstruction # Acts as a key for the producer Flow. - idx: NestedReference # Indexes the producer's outputs. - - -class ConstRef(NamedTuple): - """Convenience wrapper to access `ExternalRef(VariableKey(..., CONST))` as a reference. - - This saves us from having to plumb constants through named inputs, since: - A) `ExternalRef`s cannot appear in Symbolic outputs. - (Since that implies a producer relationship which doesn't make sense.) - B) Symbolic reference outputs must reference an input, which would mean an entry of - `{some_random_name: VariableKey(..., CONST)}` would have to be added to named inputs - which is tedious. - """ - - identifier: Any - - -@dataclasses.dataclass(frozen=True, eq=True) -class Symbolic: - """Represents abstract flow immediately after functionalization.""" - - # VariableKey: References the value of that variable at the start of the block - # OutputRef: Reference values created by an earlier instruction within the block - # SingletonValue.Tag: Reserved for special cases. - Input = parse.VariableKey | OutputRef | base.NonPyObject.Tag - inputs: base.HybridMap[Input] - - # NestedReference: Aliases the input at this position. - # AbstractValue: New value created by this instruction - Output = NestedReference | ConstRef | base.AbstractValue - outputs: tuple[Output, ...] - - BeginState = FrozenDict[parse.VariableKey, base.AbstractValue] - EndState = FrozenDict[parse.VariableKey, Input] - Block = tuple[FrozenDict[parse.ThunderInstruction, "Symbolic"], BeginState, EndState] - - def __post_init__(self) -> None: - # If an `AbstractValue` appears in `Symbolic.outputs` that implies that the symbolic - # node in question is the value's producer. However it doesn't make sense for an external - # value to be produced within the compiled function. - assert not any(isinstance(o, base.ExternalRef) for o in self.outputs), self - - @property - def uses(self) -> Iterable[parse.VariableKey]: - """Block inputs used by this node. - - NOTE: This does not include values produced by an earlier node in the block. - """ - for i in itertools.chain(self.inputs.ordered, self.inputs.named.values()): - if isinstance(i, parse.VariableKey) and not i.is_const: - yield i - - def substitute(self, replace_map: base.ReplaceMap) -> Self: - outputs = tuple(base.substitute_value(o, replace_map) for o in self.outputs) - return dataclasses.replace(self, outputs=outputs) - - -# ============================================================================= -# == Conversion from functional representation ================================ -# ============================================================================= -@dataclasses.dataclass(frozen=True) -class ParsedSymbolic: - blocks: tuple[Symbolic.Block, ...] - provenance: parse.ParsedFunctional - - @classmethod - def make(cls, parsed: parse.ParsedFunctional) -> ParsedSymbolic: - blocks: list[Symbolic.Block] = [] - for block, begin_state, end_state in parsed.blocks: - # `functionalize_blocks` produces unique values, so provenance is unambiguous. - producers: dict[parse.PlaceholderValue | None, Symbolic.Input] = {v: k for k, v in begin_state.items()} - producers[None] = base.NonPyObject.Tag.DELETED - assert len(producers) == len(begin_state) + 1, (producers, end_state) - - symbolic_blocks: dict[parse.ThunderInstruction, Symbolic] = {} - for instruction, raw_inputs, raw_outputs in block: - for idx, o in enumerate(raw_outputs): - assert o not in producers - producers[o] = OutputRef(instruction, base.Reference(idx)) - - outputs: tuple[Symbolic.Output, ...] = tuple(base.IntermediateValue() for _ in raw_outputs) - if alias := ALIAS_OPCODES.get(instruction.opname): - mask = alias(len(outputs)) if callable(alias) else alias - mask = (base.Reference(i) if i is not None else i for i in mask) - outputs = tuple(o if o_mask is None else o_mask for o, o_mask in safe_zip(outputs, mask)) - inputs = base.HybridMap(ordered=tuple(producers[i] for i in raw_inputs)) - symbolic_blocks[instruction] = Symbolic(inputs, outputs) - - begin = {k: base.AbstractRef(v) for k, v in begin_state.items() if not k.is_const} - end = {k: producers[v] for k, v in end_state.items() if not k.is_const} - blocks.append((FrozenDict(symbolic_blocks), FrozenDict(begin), FrozenDict(end))) - - return cls(tuple(blocks), parsed) diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py index 354e7ddc18..cb22553aa0 100644 --- a/thunder/core/symbol.py +++ b/thunder/core/symbol.py @@ -109,10 +109,10 @@ def default_python_printer( # A symbol represents a function and how it can be transformed # name is a string name for the operation -# meta should use lightning.compile functions to evaluate the function; -# it will be called with lightning.compile proxies +# meta should use thunder.jit functions to evaluate the function; +# it will be called with thunder.jit proxies # id is an optional value to use when translating the function to executors -# is_prim should be True if the Symbol represents a lightning.compile primitive +# is_prim should be True if the Symbol represents a thunder.jit primitive # python_printer is a function that will produce valid Python for calling the # operation; this can usually be set to None, in which case the default python # printer will be used for the Symbol. Symbols that control their own printing @@ -196,14 +196,6 @@ def module(self) -> None | ModuleType: result = inspect.getmodule(fn_) return result - # Properties used in transforms (defined later) - # TODO https://github.com/Lightning-AI/lightning-thunder/issues/326 - # Remove this from here (think how symbols could be extended with transforms) - # self.grad_defined = False - # self.grad_ignored = False - # self.grad_fwd = None - # self.grad_bwd = None - def __repr__(self) -> str: return f"[Symbol name={self.name}]" @@ -313,7 +305,6 @@ def __post_init__(self): # Constructs a new BoundSymbol with default values taken from this BoundSymbol # Override values can be specified as kwargs - # TODO https://github.com/Lightning-AI/lightning-thunder/issues/680 # Issue -- Provide a pattern for updating subsymbols when swapping outputs # Maybe this can also just swap one set of symbols for another? # Consider adding verification that the new and old output have the same metadata diff --git a/thunder/core/trace.py b/thunder/core/trace.py index 4e2bed9eba..88f1950585 100644 --- a/thunder/core/trace.py +++ b/thunder/core/trace.py @@ -19,7 +19,7 @@ from thunder.core.codeutils import ContextObject -# TODO https://github.com/Lightning-AI/lightning-thunder/issues/327 +# TODO see issue "Improve TraceProvenance" # Make this more interesting / printer better -- maybe let # practitioners acquire the pass callable so they can replicate the pass? # This class is intended to describe how the trace was constructed @@ -36,7 +36,7 @@ def __repr__(self) -> str: # TODO Should traces be BoundSymbols? -# TODO https://github.com/Lightning-AI/lightning-thunder/issues/323 +# TODO issue "Create a mechanism for freezing TraceCtx objects" # Add validation that a constant is never assigned to / reassigned # Possibly separate the ideas of a trace -- a series of scopes containing bound symbols -- # and a TraceCtx, which can produce new traces @@ -303,7 +303,7 @@ def python_ctx(self) -> dict: return import_ctx # TODO Account for multi-line signatures - # TODO https://github.com/Lightning-AI/lightning-thunder/issues/324 + # TODO issue "Add type annotations to Python function produced by traces" # Consider extending the signature with type information, in particular the # the type information of the return value might be interesting def python(self, *, print_depth: int = 1) -> str: @@ -395,7 +395,7 @@ def keyfn(class_or_module: type | ModuleType) -> str: reset_tracectx(token) # Returns a Python callable that executes the trace - # TODO https://github.com/Lightning-AI/lightning-thunder/issues/323 + # TODO issue "Create a mechanism for freezing TraceCtx objects" # Create a mechanism for freezing traces and cache the compilation def python_callable(self, *, global_dicts: None | dict = None) -> Callable: python_str: str diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 7b602084b6..fa74313bfe 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -134,9 +134,9 @@ def replace_redundant_inputs( return new_bsyms -# TODO(crcrpar): Implement a mechanism to keep track of supported ops that cannot be CSE'd. -# For example, `uniform`, `dropout`, and `scaled_dot_product_attention`. -# See: https://github.com/Lightning-AI/lightning-thunder/issues/671 +# These are ops that are not referentially transparent. We need to treat such +# ops specially when optimizing; for example, CSE cannot coalesce two calls +# into one for ops in this set. NON_FUNCTIONAL_OPS: set[prims.PrimIDs | str] = { prims.PrimIDs.UNIFORM, "torch.uniform", # this doesn't exist as of the PR diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 9922fede58..772e65a84d 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -24,6 +24,7 @@ NumberProxy, Proxy, TensorProxy, + FloatProxy, variableify, unvariableify, CollectionProxy, @@ -59,6 +60,7 @@ convolution, ) from thunder.core.transform_common import dce +from thunder.core.vjp_utils import make_aug_forward_and_backward from thunder.extend import Executor import thunder.torch as ltorch @@ -538,24 +540,29 @@ def _flatten(bsym: BoundSymbol): # # -# Functions related to functionalizing TOMs +# Functions related to functionalizing ThunderOptimizedModules # # TODO Test with buffers def populate_grads(grads: list[TensorProxy], tom: None | torch.nn.Module = None, args=None, kwargs=None) -> None: idx: int = 0 - from thunder.common import ThunderOptimizedModule - - if tom is not None and isinstance(tom, ThunderOptimizedModule) and tom._additional_param_values is not None: - for p in tom._additional_param_values: - if p.requires_grad: + from thunder import ThunderModule, compile_data + + if isinstance(tom, ThunderModule) or thunder.compile_data(tom).using_jit: + assert args is not None, "populate grad needs args (and possibly kwargs) to work with ThunderModules" + if kwargs is None: + kwargs = {} + _, computation_inputs, _ = compile_data(tom).get_computation_and_inputs(*args, **kwargs) + for p in computation_inputs: + if isinstance(p, torch.Tensor) and p.requires_grad: # Supports grad accumulation (like when weight tying) if p.grad is not None: p.grad += grads[idx] else: p.grad = grads[idx] idx += 1 + return # Short-circuits if there are no args or kwargs if args is None and kwargs is None: @@ -587,7 +594,6 @@ def clear_grads(module: torch.nn.Module) -> None: b.grad = None -from thunder.core.script.noinline import noinline from thunder.core.interpreter import make_opaque from thunder.core.langctxs import langctx, Languages @@ -595,7 +601,7 @@ def clear_grads(module: torch.nn.Module) -> None: # TODO RC1 Replace with langctx def torchctx(fn): _fn = langctx(Languages.TORCH)(fn) - return make_opaque(noinline(_fn)) + return make_opaque(_fn) _grad_fn_map: dict[Any, Callable] = {} @@ -664,8 +670,16 @@ def _convert_element_type_prim_grad(a: Number | TensorProxy, dtype: type | dtype # NOTE prims.iota creates no grad associations register_grad(pids.IOTA, prims.iota) -# NOTE prims.uniform creates no grad associations -register_grad(pids.UNIFORM, prims.uniform) + +def _uniform_grad(shape, minval, maxval, *, device, dtype): + fwd, saved = uniform_aug_fwd(shape, minval, maxval, device=device, dtype=dtype) + g = get_grad(fwd) + _, gminval, gmaxval = uniform_backward(*saved, g) + put_grads((minval, maxval), (gminval, gmaxval)) + return fwd + + +register_grad(pids.UNIFORM, _uniform_grad) # # Reshaping and permuting operator grads @@ -863,12 +877,11 @@ def _abs_prim_grad(a: Number | TensorProxy) -> Number | TensorProxy: register_grad(pids.ABS, _abs_prim_grad) -@torchctx def _cos_prim_grad(a: Number | TensorProxy) -> Number | TensorProxy: - fwd = prims.abs(a) + fwd = prims.cos(a) g = get_grad(fwd) - put_grad(a, g * (-ltorch.sin(a))) + put_grad(a, g * (-prims.sin(a))) return fwd @@ -948,12 +961,11 @@ def _rsqrt_prim_grad(a: Number | TensorProxy, /) -> Number | TensorProxy: register_grad(pids.RSQRT, _rsqrt_prim_grad) -@torchctx def _sin_prim_grad(a: Number | TensorProxy) -> Number | TensorProxy: - fwd = prims.abs(a) + fwd = prims.sin(a) g = get_grad(fwd) - put_grad(a, g * ltorch.cos(a)) + put_grad(a, g * prims.cos(a)) return fwd @@ -1224,7 +1236,9 @@ def _embedding_prim_grad( # -def _get_gradfn(bsym: BoundSymbol, *, executors_list: Sequence[Any]) -> None | Callable: +def _get_gradfn(bsym: BoundSymbol, *, executors_list: Sequence[Any] = tuple()) -> None | Callable: + cd = get_compile_data() + executors_list = cd.executors_list if cd is not None else executors_list # Checks if the executor which has priority for this operation has a specific grad transform for it for ex in executors_list: if ex.can_execute_or_fuse(bsym): @@ -1564,7 +1578,7 @@ def _selector(eligible_nodes: list[Node]) -> int: return gradtrc - # NOTE This is a kludge to indicate that we shouldn't support PyTorch's autograd because + # NOTE This is a kludge to indicate that we shouldn't use PyTorch's autograd because # we're using our own autograd transform cfn._using_grad_transform = True @@ -1831,7 +1845,6 @@ def broadcast_in_dim_vmap( ) -> BatchedValue: bdim = a.batch_dim # TODO: remove this when shape and broadcast_dimensions become mandatory kwargs - # See https://github.com/Lightning-AI/lightning-thunder/issues/181 shape, _ = safe_zip(*shape) if len(broadcast_dimensions) > 0: broadcast_dimensions, _ = safe_zip(*broadcast_dimensions) @@ -2148,7 +2161,6 @@ def broadcast_in_dim_jvp(a: JVPDual, shape: tuple[JVPDual, ...], broadcast_dimen x, xd = a # TODO: shape and broadcast_dimensions should be tuples of ints # but for now it's a tuple of JVPDuals - # See https://github.com/Lightning-AI/lightning-thunder/issues/181 if len(shape) > 0 and isinstance(shape[0], JVPDual): shape, _ = safe_zip(*shape) if len(broadcast_dimensions) > 0 and isinstance(broadcast_dimensions[0], JVPDual): @@ -2408,43 +2420,30 @@ def zeros_like(x): # The augmented_primal function takes the primal values and returns the primal # result and the residuals (saved values for the backward). augmented_forward_impls = { - prims.PrimIDs.ABS: lambda x: (prims.abs(x), (x,)), prims.PrimIDs.ACOS: lambda x: (prims.acos(x), (x,)), prims.PrimIDs.ACOSH: lambda x: (prims.acosh(x), (x,)), - prims.PrimIDs.ADD: lambda x, y: (prims.add(x, y), tuple()), prims.PrimIDs.ASIN: lambda x: (prims.asin(x), (x,)), prims.PrimIDs.ASINH: lambda x: (prims.asinh(x), (x,)), prims.PrimIDs.ATAN: lambda x: (prims.atan(x), (x,)), prims.PrimIDs.ATANH: lambda x: (prims.atanh(x), (x,)), prims.PrimIDs.ATAN2: lambda x, y: (prims.atan2(x, y), (x, y)), - prims.PrimIDs.COS: lambda x: (prims.cos(x), (x,)), prims.PrimIDs.COSH: lambda x: (prims.cosh(x), (x,)), prims.PrimIDs.DIGAMMA: lambda x: (prims.digamma(x), (x,)), - prims.PrimIDs.DIV: lambda x, y: (prims.div(x, y), (x, y)), - prims.PrimIDs.ERF: lambda x: (prims.erf(x), (x,)), prims.PrimIDs.ERFC: lambda x: (prims.erfc(x), (x,)), prims.PrimIDs.ERFINV: lambda x: (prims.erfinv(x), (prims.erfinv(x),)), prims.PrimIDs.ERFCINV: lambda x: (prims.erfcinv(x), (prims.erfcinv(x),)), prims.PrimIDs.EXP2: lambda x: (prims.exp2(x), (prims.exp2(x),)), prims.PrimIDs.EXPM1: lambda x: (prims.expm1(x), (prims.expm1(x),)), prims.PrimIDs.LGAMMA: lambda x: (prims.lgamma(x), (x,)), - prims.PrimIDs.MUL: lambda x, y: (prims.mul(x, y), (x, y)), prims.PrimIDs.NDTRI: lambda x: (prims.ndtri(x), (prims.ndtri(x),)), - prims.PrimIDs.SIN: lambda x: (prims.sin(x), (x,)), prims.PrimIDs.SINH: lambda x: (prims.sinh(x), (x,)), - prims.PrimIDs.SUB: lambda x, y: (prims.sub(x, y), tuple()), prims.PrimIDs.SQRT: lambda x: (prims.sqrt(x), (prims.sqrt(x),)), - prims.PrimIDs.EQ: lambda x, y: (prims.eq(x, y), (x, y)), prims.PrimIDs.NE: lambda x, y: (prims.ne(x, y), (x, y)), - prims.PrimIDs.GE: lambda x, y: (prims.ge(x, y), (x, y)), prims.PrimIDs.GT: lambda x, y: (prims.gt(x, y), (x, y)), prims.PrimIDs.LE: lambda x, y: (prims.le(x, y), (x, y)), - prims.PrimIDs.LT: lambda x, y: (prims.lt(x, y), (x, y)), - prims.PrimIDs.LOG: lambda x: (prims.log(x), (x,)), prims.PrimIDs.LOG10: lambda x: (prims.log10(x), (x,)), prims.PrimIDs.LOG1P: lambda x: (prims.log1p(x), (x,)), prims.PrimIDs.LOG2: lambda x: (prims.log2(x), (x,)), - prims.PrimIDs.NEG: lambda x: (prims.neg(x), tuple()), prims.PrimIDs.ZETA: lambda x, y: (prims.zeta(x, y), (x, y)), prims.PrimIDs.FMOD: lambda x, y: (prims.fmod(x, y), (x, y)), } @@ -2454,55 +2453,32 @@ def zeros_like(x): # The backward function takes the residuals and cotangents and returns the # vector-Jacobian products for each argument. backward_impls = { - prims.PrimIDs.ABS: lambda x, g: g * prims.sign(x), prims.PrimIDs.ACOS: lambda x, g: -g / prims.sqrt(1.0 - x * x), prims.PrimIDs.ACOSH: lambda x, g: g * prims.rsqrt(x * x - 1.0), - prims.PrimIDs.ADD: lambda g: (g, g), prims.PrimIDs.ASIN: lambda x, g: g / prims.sqrt(1.0 - x * x), prims.PrimIDs.ASINH: lambda x, g: g * prims.rsqrt(1.0 + x * x), prims.PrimIDs.ATAN: lambda x, g: g / (1.0 + x * x), prims.PrimIDs.ATANH: lambda x, g: g / (1.0 - x * x), - prims.PrimIDs.COS: lambda x, g: prims.mul(g, -prims.sin(x)), prims.PrimIDs.COSH: lambda x, g: prims.mul(g, prims.sinh(x)), - prims.PrimIDs.DIV: lambda x, y, g: (g / y, -g * x / (y**2)), - prims.PrimIDs.ERF: lambda x, g: g * 2.0 / math.sqrt(math.pi) * prims.exp(-x * x), prims.PrimIDs.ERFC: lambda x, g: -g * 2.0 / math.sqrt(math.pi) * prims.exp(-x * x), prims.PrimIDs.ERFINV: lambda result, g: g * 0.5 * math.sqrt(math.pi) * prims.exp(result**2), prims.PrimIDs.ERFCINV: lambda result, g: -g * 0.5 * math.sqrt(math.pi) * prims.exp(result**2), prims.PrimIDs.EXP2: lambda result, g: g * result * math.log(2.0), prims.PrimIDs.EXPM1: lambda result, g: g * (result + 1.0), prims.PrimIDs.LGAMMA: lambda x, g: g * prims.digamma(x), - prims.PrimIDs.MUL: lambda x, y, g: (g * y, g * x), prims.PrimIDs.NDTRI: lambda result, g: g * prims.exp(0.5 * result**2) * math.sqrt(2.0 * math.pi), - prims.PrimIDs.SIN: lambda x, g: prims.mul(g, prims.cos(x)), prims.PrimIDs.SINH: lambda x, g: prims.mul(g, prims.cosh(x)), - prims.PrimIDs.SUB: lambda g: (g, -g), prims.PrimIDs.SQRT: lambda result, g: g / (2.0 * result), - prims.PrimIDs.FULL: NoPullback(num_args=2), - prims.PrimIDs.EQ: ZeroBackward(num_args=2), prims.PrimIDs.NE: ZeroBackward(num_args=2), - prims.PrimIDs.GE: ZeroBackward(num_args=2), prims.PrimIDs.GT: ZeroBackward(num_args=2), prims.PrimIDs.LE: ZeroBackward(num_args=2), - prims.PrimIDs.LT: ZeroBackward(num_args=2), - prims.PrimIDs.LOG: lambda x, g: g / x, prims.PrimIDs.LOG10: lambda x, g: g / (x * 2.302585092994046), prims.PrimIDs.LOG1P: lambda x, g: g / (x + 1), prims.PrimIDs.LOG2: lambda x, g: g / (x * 0.6931471805599453), - prims.PrimIDs.NEG: lambda g: -g, prims.PrimIDs.FMOD: lambda x, y, g: (g, -g * prims.trunc(x / y)), } -@dataclass(**default_dataclass_params) -class RuleInfo: - checker: Callable - rule: Callable - fw_fallback: Callable - bw_fallback: Callable - executor: Executor - - def register_augmented_forward(op): """Decorator to register an augmented forward implementation for a symbol. @@ -2520,40 +2496,6 @@ def decorator(func): return decorator -def register_augmented_forward_with_checker(executor, op, checker, rule): - """Decorator to register an augmented forward implementation for a symbol. - - Args: - executor (Executor): Executor to which the rule applies. - op (Ops): Symbol for which to register the augmented forward implementation. - checker (Callable): Function that checks if the rule should be applied. - rule (Callable): Function that applies the rule. - """ - fw_fallback = augmented_forward_impls.get(op, None) - bw_fallback = backward_impls.get(op, None) - augmented_forward_impls[executor, op] = RuleInfo(checker, rule, fw_fallback, bw_fallback, executor) - - -def deregister_augmented_forward_and_backward(op): - """Deregisters an augmented forward implementation and a backward - implementation for a symbol. - - Args: - op (Ops): Symbol for which to deregister the augmented forward - implementation and the backward implementation. - - Returns: - None - """ - # Restore the fallback implementation if it exists - if isinstance(augmented_forward_impls[op], RuleInfo): - backward_impls[op] = augmented_forward_impls[op].bw_fallback - augmented_forward_impls[op] = augmented_forward_impls[op].fw_fallback - else: - del augmented_forward_impls[op] - del backward_impls[op] - - def register_backward(op): """Decorator to register a backward implementation for a symbol. @@ -2623,29 +2565,6 @@ def polygamma_backward(n: int, a: Proxy, g): return None, g * polygamma(n + 1, a) -@register_augmented_forward(prims.PrimIDs.RSQRT) -def rsqrt_augmented(x): - """Augmented rsqrt operation. - - Args: - x (Variable): input tensor. - - Returns: - VJPDual: Primal and residuals. - """ - primal = prims.rsqrt(x) - residuals = (primal,) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.RSQRT) -def rsqrt_backward(result, g): - # An alternative derivation used by JAX is -0.5 * g * rsqrt(x) / x - # where rsqrt(x) and x are saved for the backwards pass. - # This derivation was selected because it avoids saving the input tensor. - return -0.5 * g * result**3.0 - - @register_backward(prims.PrimIDs.ATAN2) def atan2_backward(x, y, g): alpha = 1.0 / (x * x + y * y) @@ -2654,32 +2573,6 @@ def atan2_backward(x, y, g): return grad_x, grad_y -@register_augmented_forward(prims.PrimIDs.SUM) -def sum_aug_fwd(x, dims): - """Augmented sum operation. - - Args: - x (Variable): Tensor to be summed. - dims (Tuple[int, ...]): Dimensions to be summed. - - Returns: - VJPDual: Primal and residuals. - """ - primal = prims.sum(x, dims) - residuals = ( - x.shape, - dims, - ) - - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.SUM) -def sum_backward(x_shape, reduced_dims, g): - # One return per positional argument of prims.sum - return restore_reduced_dims(g, reduced_dims, x_shape), None - - @register_augmented_forward(prims.PrimIDs.VAR) def var_aug_fwd(a, dim, *, correction): v = prims.var(a, dim, correction=correction) @@ -2701,13 +2594,6 @@ def var_backward(a, dim, correction, v, g): return (2 * g * (a - mean)) / normalization_scalar -@register_augmented_forward(prims.PrimIDs.VAR_MEAN) -def _var_mean_aug_fwd(a, dim, *, correction): - v, m = prims.var_mean(a, dim, correction=correction) - - return (v, m), (a, dim, correction, m) - - def n_elem_reduced(a_ndim, a_shape, dims): dims = utils.canonicalize_dims(a_ndim, dims) reduction_size = 1 @@ -2722,27 +2608,6 @@ def mean_backward(a_ndim, a_shape, dims, grad): return restore_reduced_dims(grad, dims, a_shape) * mean_local_grad -# TODO: fix division by zero when n_elem_reduced == 0 or when mean.numel == 0 -# by returning zeros_like(a) or similar. -# TODO: fix grad when correction > n_elem_reduced. -@register_backward(prims.PrimIDs.VAR_MEAN) -def _var_mean_bwd(a, dim, correction, mean, grad_v, grad_m): - n_elem_reduced = a.numel // mean.numel if a.numel != 0 else 1 - - def mean_backward(a, dims, grad): - mean_scale = 1.0 / n_elem_reduced - grad = restore_reduced_dims(grad, dims, a.shape) - return mean_scale * grad - - def var_backward(a, dims, correction, mean, grad): - normalization_scalar = n_elem_reduced - correction - grad = restore_reduced_dims(grad, dims, a.shape) - mean = restore_reduced_dims(mean, dims, a.shape) - return (2.0 * grad * (a - mean)) / normalization_scalar - - return var_backward(a, dim, correction, mean, grad_v) + mean_backward(a, dim, grad_m) - - @register_augmented_forward(prims.PrimIDs.PAD) def pad_aug_fwd(a, padding_value, padding_config): return VJPDual((prims.pad(a, padding_value, padding_config),), (a, padding_config)) @@ -2812,42 +2677,15 @@ def grad_chooser_backward(primal, x, x_shape, reduced_dims, g): return out -register_backward(prims.PrimIDs.AMAX)(grad_chooser_backward) register_backward(prims.PrimIDs.AMIN)(grad_chooser_backward) -# TODO: exact same for amin, argmax, argmin -@register_augmented_forward(prims.PrimIDs.AMAX) -def amax_aug_fwd(x, dims): - """Augmented amax operation. - - Args: - x (Variable): Tensor to compute amax on. - dims (Tuple[int, ...]): Dimensions to compute amax over. - - Returns: - VJPDual: Primal and residuals. - """ - primal = prims.amax(x, dims) - - residuals = ( - primal, - x, - x.shape, - dims, - ) - - return VJPDual(primal, residuals) - - @register_augmented_forward(prims.PrimIDs.AMIN) def amin_aug_fwd(x, dims): """Augmented amin operation. - Args: x (Variable): Tensor to compute amin on. dims (Tuple[int, ...]): Dimensions to compute amin over. - Returns: VJPDual: Primal and residuals. """ @@ -2863,26 +2701,6 @@ def amin_aug_fwd(x, dims): return VJPDual(primal, residuals) -@register_augmented_forward(prims.PrimIDs.EXP) -def exp_aug_fwd(x): - """Augmented exp operation. - - Args: - x (Variable): Tensor to be exponentiated. - - Returns: - VJPDual: Primal and residuals. - """ - primal = prims.exp(x) - residuals = (primal,) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.EXP) -def exp_backward(result, g): - return g * result - - @register_augmented_forward(prims.PrimIDs.POW) def pow_aug_fed(x, y): """Augmented the pow operation. @@ -2927,106 +2745,11 @@ def tan_backward(result, g): return g * (1 + result * result) -@register_augmented_forward(prims.PrimIDs.TANH) -def tanh_aug_fwd(x): - """Augmented tanh operation. - - Args: - x (Variable): Tensor to be passed to tanh. - - Returns: - VJPDual: Primal and residuals. - """ - primal = prims.tanh(x) - residuals = (primal,) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.TANH) -def tanh_backward(result, g): - return g * (1.0 - result * result) - - # NOTE: Jax uses np.argsort in its transpose vjp computation def _argsort(seq): return sorted(range(len(seq)), key=seq.__getitem__) -@register_augmented_forward(prims.PrimIDs.TRANSPOSE) -def transpose_aug_fwd(a, permutation): - primal = prims.transpose(a, tuple(permutation)) - residuals = (permutation,) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.TRANSPOSE) -def transpose_backward(permutation, g): - undo = _argsort(permutation) - return prims.transpose(g, tuple(undo)) - - -@register_augmented_forward(prims.PrimIDs.RESHAPE) -def reshape_aug_fwd(a, shape): - primal = prims.reshape(a, shape) - residuals = (a.shape,) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.RESHAPE) -def reshape_backward(orig_shape, g): - return prims.reshape(g, orig_shape) - - -@register_augmented_forward(prims.PrimIDs.SLICE) -def slice_aug_fwd(a, start_indices, end_indices, strides): - primal = prims.slice_prim(a, start_indices, end_indices, strides) - residuals = (a.shape, start_indices, end_indices, strides) - return VJPDual(primal, residuals) - - -# Adapted from https://github.com/google/jax/blob/main/jax/_src/lax/slicing.py#L768 -@register_backward(prims.PrimIDs.SLICE) -def slice_backward(shape, start_indices, end_indices, strides, g): - padding = None - if strides is None or np.all(np.equal(strides, 1)): - padding = tuple(zip(start_indices, np.subtract(shape, end_indices), (0,) * len(start_indices))) - else: - real_limits = np.add( - start_indices, - np.where(np.equal(g.shape, 0), 0, np.add(1, np.multiply(np.subtract(g.shape, 1), strides))), - ) - padding = tuple(zip(start_indices, np.subtract(shape, real_limits), np.subtract(strides, 1))) - - # We used NumPy arithmetics above, but the current infra expects Python ints. - padding = tree_map(int, padding) - result = prims.pad(g, const_as(0, g.dtype), padding) - - return result - - -@register_augmented_forward(prims.PrimIDs.BROADCAST_IN_DIM) -def broadcast_in_dim_aug_fwd(a: Proxy, shape: Sequence[int], broadcast_dimensions: Sequence[int]) -> VJPDual: - primal = prims.broadcast_in_dim(a, shape, broadcast_dimensions) - residuals = (a, shape, broadcast_dimensions) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.BROADCAST_IN_DIM) -def broadcast_in_dim_backward(a, shape, broadcast_dimensions, g): - from thunder.torch import sum - - # If g is None, then the primal was a constant and the pullback is zero. - # TODO: implement None propagation in the VJP infrastructure so that we don't need to do this. - if g is None: - return None, None, None - unit_dims = tuple(i for i, s in enumerate(a.shape) if s == 1) - bcast_dims = tuple(b for i, b in enumerate(broadcast_dimensions) if i not in unit_dims) - reduce_dims = tuple(s for i, s in enumerate(range(len(shape))) if i not in bcast_dims) - g = sum(g, reduce_dims) - g = unsqueeze(g, unit_dims) - return g - - @register_augmented_forward(prims.PrimIDs.DEVICE_PUT) def device_put_aug_fwd(a: TensorProxy, device: Device) -> TensorProxy: primal = prims.device_put(a, device) @@ -3039,19 +2762,6 @@ def device_put_backward(orig_device, g): return prims.device_put(g, orig_device), None -@register_augmented_forward(prims.PrimIDs.CONVERT_ELEMENT_TYPE) -def convert_element_type_aug_fwd(a: Proxy, dtype: dtypes.dtype) -> VJPDual: - primal = prims.convert_element_type(a, dtype) - residuals = (a.dtype if isinstance(a, TensorProxy) else (a.python_type if isinstance(a, NumberProxy) else type(a)),) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.CONVERT_ELEMENT_TYPE) -def convert_element_type_backward(a_dtype, g): - # perform cast back to input type during backward - return prims.convert_element_type(g, a_dtype), None - - @register_augmented_forward(prims.PrimIDs.CONVOLUTION) def convolution_aug_fwd( a: Proxy, @@ -3253,32 +2963,6 @@ def pad_transpose_and_push_groups_into_batches(t): return (input_grad, weight_grad, bias_grad) -@register_augmented_forward("torch.nn.functional.cross_entropy") -def cross_entropy_aug_fwd( - input: Proxy, - target: Proxy, - weight=None, - size_average=None, - ignore_index=-100, - reduce=None, - reduction="mean", - label_smoothing=0.0, -) -> VJPDual: - from thunder.torch import cross_entropy - - primal = cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing) - residuals = (input, target, weight, reduction, ignore_index, label_smoothing) - return VJPDual(primal, residuals) - - -@register_backward("torch.nn.functional.cross_entropy") -def cross_entropy_backward(input, target, weight, reduction, ignore_index, label_smoothing, g): - from thunder.torch import cross_entropy_backward - - ginput = cross_entropy_backward(g, input, target, weight, reduction, ignore_index, label_smoothing) - return ginput - - @register_augmented_forward("torch.log_softmax") def log_softmax_aug_fwd(input: TensorProxy, dim: int, *, dtype=None) -> VJPDual: from thunder.torch import log_softmax @@ -3408,64 +3092,6 @@ def softmax_backward(primal, dim, g): return primal * (g - (primal * g).sum(dim, keepdim=True)) -@register_augmented_forward(prims.PrimIDs.MATMUL) -def matmul_aug_fwd(a: TensorProxy, b: TensorProxy) -> VJPDual: - primal = prims.matmul(a, b) - residuals = (a, b) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.MATMUL) -def matmul_backward(a, b, g): - from thunder.torch import sum - - last_dim = (-1,) - first_dim = (-2,) - if a.ndim == 1 and b.ndim == 1: - return g * b, g * a - - if b.ndim == 1: - ga = unsqueeze(g, last_dim) @ unsqueeze(b, last_dim).mT - gb = a.mT @ unsqueeze(g, last_dim) - if g.ndim > 1: - gb = squeeze(gb, last_dim) - gb = sum(gb, tuple(range(gb.ndim - 1))) - return ga, gb - - if a.ndim == 1: - ga = unsqueeze(g, first_dim) @ b.mT - if g.ndim > 1: - ga = sum(ga, tuple(range(ga.ndim - 1))) - gb = unsqueeze(a, first_dim).mT @ unsqueeze(g, first_dim) - return ga, gb - - return g @ b.mT, a.mT @ g - - -@register_augmented_forward(prims.PrimIDs.LINEAR) -def linear_aug_fwd(a: TensorProxy, b: TensorProxy, c: TensorProxy | None) -> VJPDual: - primal = prims.linear(a, b, c) - residuals = (a, b, c) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.LINEAR) -def linear_backward(a, b, c, g): - from thunder.torch import matmul, sum - - first_dim = (-2,) - ga = matmul(g.reshape(-1, g.shape[-1]), b).reshape(a.shape) - if a.ndim == 1: - gb = matmul(unsqueeze(g, first_dim).mT, unsqueeze(a, first_dim)) - else: - gb = matmul(g.reshape(-1, g.shape[-1]).mT, a.reshape(-1, a.shape[-1])) - assert list(gb.shape) == list(b.shape), f"linear_backward: {gb.shape} != {b.shape}" - if c is None: - return ga, gb, None - gc = sum(g, tuple(range(g.ndim - 1))) if g.ndim > 1 else g - return ga, gb, gc - - def iter_bound_symbols(bound_symbols): """Iterate over bound symbols, skipping symbols that are not supported by the transforms infrastructure. @@ -3558,31 +3184,6 @@ def decomposed_fn_backward_rule(decomposed_fn, args, kwargs, saved_for_backward, return result -@register_augmented_forward(prims.PrimIDs.CAT) -def cat_aug_fwd(tensors: list[TensorProxy], dim: int) -> VJPDual: - primal = prims.cat(tensors, dim) - residuals = ( - type(tensors), - [t.shape[dim] for t in tensors], - dim, - ) - - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.CAT) -def cat_backward( - tensors_seq_type: type, tensor_dim_lens: list[int], dim: int, g: TensorProxy -) -> tuple[Sequence[TensorProxy]]: - grads = [] - - slice_start = 0 - for dim_len in tensor_dim_lens: - grads.append(slice_in_dim(g, slice_start, slice_start + dim_len, dim=dim)) - slice_start += dim_len - return (tensors_seq_type(grads),) - - @register_augmented_forward("torch.Tensor.contiguous") @register_augmented_forward("torch.contiguous") def contiguous_aug_fwd(x: TensorProxy, /, *, memory_format: torch.memory_format = torch.contiguous_format) -> VJPDual: @@ -3599,25 +3200,11 @@ def contiguous_backward(*residuals_and_grad) -> TensorProxy: return g -@register_augmented_forward(prims.PrimIDs.WHERE) -def where_aug_fwd(condition: TensorProxy, x: TensorProxy, y: TensorProxy) -> VJPDual: - primal = prims.where(condition, x, y) - residuals = (condition,) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.WHERE) -def where_backward(condition, g): - return prims.where(condition, g, 0.0), prims.where(condition, 0.0, g) - - -@register_augmented_forward(prims.PrimIDs.RECIPROCAL) def reciprocal_aug_fwd(a: TensorProxy) -> VJPDual: primal = reciprocal(a) return VJPDual(primal, (primal,)) -@register_backward(prims.PrimIDs.RECIPROCAL) def reciprocal_backward(primal, g): return -g * primal * primal @@ -3631,38 +3218,6 @@ def reciprocal_joint_forward_backward_rule(a: TensorProxy) -> TensorProxy: return result -@register_augmented_forward(prims.PrimIDs.SQUEEZE) -def squeeze_aug_fwd(a: TensorProxy, dims: Sequence[int]) -> VJPDual: - primal = squeeze(a, dims) - residuals = (dims,) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.SQUEEZE) -def squeeze_backward(dims: Sequence[int], g: TensorProxy) -> TensorProxy: - return unsqueeze(g, dims) - - -@register_augmented_forward(prims.PrimIDs.TAKE) -def take_aug_fwd(x: TensorProxy, index: TensorProxy, dim: int) -> VJPDual: - primal = prims.take(x, index, dim) - residuals = ( - x.shape, - x.device, - x.dtype, - index, - dim, - ) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.TAKE) -def take_backward( - shape: Sequence[int], device: Device, dtype: dtypes.dtype, index: TensorProxy, dim: int, g: TensorProxy -): - return prims.index_add(prims.full(shape, fill_value=0, device=device, dtype=dtype), index, g, dim) - - @register_augmented_forward("torch.index_put") def index_put_aug_fwd( a: TensorProxy, /, indices: Sequence[TensorProxy], values: TensorProxy, accumulate: bool = False @@ -3701,33 +3256,11 @@ def index_put_backward(indices: Sequence[TensorProxy], values: TensorProxy, accu return clang.index_put(g, indices, ltorch.zeros_like(values), False), g_values -@register_augmented_forward(prims.PrimIDs.TAKE_ALONG_AXIS) -def take_along_axis_aug_fwd(x: TensorProxy, index: TensorProxy, dim: int) -> VJPDual: - primal = prims.take_along_axis(x, index, dim) - residuals = ( - x.shape, - x.device, - x.dtype, - index, - dim, - ) - return VJPDual(primal, residuals) - - -@register_backward(prims.PrimIDs.TAKE_ALONG_AXIS) -def take_along_axis_backward( - shape: Sequence[int], device: Device, dtype: dtypes.dtype, index: TensorProxy, dim: int, g: TensorProxy -): - return prims.scatter_add(prims.full(shape, fill_value=0, device=device, dtype=dtype), index, g, dim) - - -@register_augmented_forward(prims.PrimIDs.UNIFORM) def uniform_aug_fwd(shape, minval, maxval, *, device, dtype): primal = prims.uniform(shape, minval, maxval, device=device, dtype=dtype) return VJPDual(primal, (primal, minval, maxval)) -@register_backward(prims.PrimIDs.UNIFORM) def uniform_backward(primal, minval, maxval, g): # uniform is implemented as (maxval - minval) * uniform(shape, 0, 1) + minval unscaled_primal = (primal - minval) / (maxval - minval) @@ -3739,29 +3272,21 @@ def uniform_backward(primal, minval, maxval, g): nondifferentiable_vjp_symbols = (prims.PrimIDs.BITWISE_AND, prims.PrimIDs.SIGNBIT, prims.PrimIDs.FULL) -def get_executor_specific_aug_fwd_rule(symbol) -> RuleInfo | None: - """Get executor specific augmented forward rule. +def is_constant_for_vjp(symbol: prims.Symbol) -> bool: + """Check if a symbol is constant for the VJP transform. Args: - symbol (prims.Symbol): Symbol to get the rule for. + symbol (prims.Symbol): Symbol to check. Returns: - RuleInfo: Rule info for the symbol. + bool: True if the symbol is constant, False otherwise. """ - cd = get_compile_data() - if cd is None: - return None - - # Search for the executor specific rules. When there are multiple rules - # for the same symbol, we use the left-most executor in the list (i.e. - # the one with the highest priority) and we fallback to the next one if - # the checker returns False. - for executor in cd.executors_list: - candidate = augmented_forward_impls.get((executor, symbol.sym.id)) - if isinstance(candidate, RuleInfo) and candidate.checker(*symbol.args, **symbol.kwargs): - return candidate - - return None + are_all_args_non_differentiable = not any(isinstance(arg, (FloatProxy, TensorProxy)) for arg in symbol.flat_args) + return ( + are_all_args_non_differentiable + or symbol.are_all_args_constant + or symbol.sym.id in nondifferentiable_vjp_symbols + ) def vjp_symbol_mapper(symbol: prims.Symbol, *args, **kwargs): @@ -3776,7 +3301,7 @@ def vjp_symbol_mapper(symbol: prims.Symbol, *args, **kwargs): Callable: A function that computes the VJP of the symbol. """ # Constant case - if symbol.are_all_args_constant or symbol.sym.id in nondifferentiable_vjp_symbols: + if is_constant_for_vjp(symbol): def vjp_impl_const(symbol, *args, **kwargs): args, kwargs = tree_map(lambda x: x.primal if isinstance(x, VJPDual) else x, (args, kwargs)) @@ -3790,15 +3315,8 @@ def vjp_impl_const(symbol, *args, **kwargs): # Normal case, we have a proxy tangent vjp_impl = augmented_forward_impls.get(symbol.sym.id) - vjp_impl = get_executor_specific_aug_fwd_rule(symbol) or vjp_impl - - if isinstance(vjp_impl, RuleInfo): - # We should use this rule only if checker returns True for the current - # symbol's arguments - if vjp_impl.checker(*symbol.args, **symbol.kwargs): - vjp_impl = vjp_impl.rule - else: - vjp_impl = vjp_impl.fw_fallback + if _get_gradfn(symbol) is not None: + vjp_impl, backward_fn = make_aug_forward_and_backward(symbol) if vjp_impl is None: # We could not find a VJP for this symbol, so we try to decompose it @@ -3809,6 +3327,7 @@ def vjp_impl_const(symbol, *args, **kwargs): # It could be a torch.dropout with 0.0 probability, so we skip it if symbol.sym.id == "torch.nn.functional.dropout": return None + print(f"VJP for {symbol} is not implemented") raise NotImplementedError(f"VJP for {symbol.sym.id} is not implemented") def _vjp_impl(*args, **kwargs): @@ -3842,6 +3361,9 @@ def check_bsym_for_vjp(bsym): if bsym.sym.id in backward_impls and bsym.sym.id in augmented_forward_impls: return True + if bsym.sym.id in _grad_fn_map: + return True + # We could not find a VJP for this symbol, so we try to decompose it # into sub-symbols and check if they are supported if len(bsym.subsymbols) > 0 and not bsym.sym.is_prim: @@ -3925,6 +3447,8 @@ def put_grad(v: Variable, val: Any) -> None: elif isinstance(v, Sequence) and val is None: # broadcast None to the right shape safe_map(put_grad, v, [None] * len(v)) + elif isinstance(v, Sequence) and isinstance(val, Sequence): + safe_map_flat(put_grad, v, val) else: # Skip writing to constants pass @@ -3943,7 +3467,7 @@ def put_grad(v: Variable, val: Any) -> None: # Otherwise, we will need to rewrite the pullback functions cotangents = tree_flatten(cotangents)[0] residuals = forward_env[symbol_output[0].name].residuals - if symbol.are_all_args_constant or symbol.sym.id in nondifferentiable_vjp_symbols: + if is_constant_for_vjp(symbol): # We can skip the pullback if all the arguments are constant continue @@ -3955,17 +3479,15 @@ def put_grad(v: Variable, val: Any) -> None: if symbol.sym.id == "torch.nn.functional.dropout" and not symbol.subsymbols: # We can skip the pullback if the dropout probability is 0.0 # Assuming that the dropout symbol has the same output and argument - # https://github.com/Lightning-AI/lightning-thunder/issues/906 assert symbol.output.name == symbol.args[0].name, "Dropout symbol has a different output and argument" if symbol.args[1] == 0.0 or symbol.args[2] is False: continue backward = backward_impls.get(symbol.sym.id) aug_forward = augmented_forward_impls.get(symbol.sym.id) - aug_forward = get_executor_specific_aug_fwd_rule(symbol) or aug_forward - if isinstance(aug_forward, RuleInfo): - backward = backward_impls[aug_forward.executor, symbol.sym.id] + if _get_gradfn(symbol) is not None: + aug_forward, backward = make_aug_forward_and_backward(symbol) if backward is None: if len(symbol.subsymbols) > 0 and not isinstance(symbol.sym.id, prims.PrimIDs): @@ -3980,9 +3502,11 @@ def put_grad(v: Variable, val: Any) -> None: # If the backward returns a dict, we assume that it is a dict of # forward arguments to the corresponding # gradients/cotangents/adjoints/sensitivities. + used_names = set() for i, (k, v) in enumerate(inspect.signature(aug_forward).parameters.items()): if v.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD): put_grad(symbol.args[i], result.get(k, None)) + used_names.add(k) # For developer convenience, we allow using the name from the # forward meta in addition to the name from the augmented forward @@ -3991,7 +3515,8 @@ def put_grad(v: Variable, val: Any) -> None: # precedence. for i, (k, v) in enumerate(inspect.signature(symbol.sym.meta).parameters.items()): if v.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD): - put_grad(symbol.args[i], result.get(k, None)) + if k not in used_names: + put_grad(symbol.args[i], result.get(k, None)) continue if not isinstance(result, Sequence): @@ -4028,7 +3553,7 @@ def is_differentiable(arg): result = tuple(next(iter_result) if is_differentiable(arg) else None for arg in symbol.args) - # See https://github.com/Lightning-AI/lightning-thunder/issues/977. + # See "Backward impl for ops of the type Sequence[TensorProxy], ... -> ... results in None grads." # This is a temporary workaround. if symbol.sym.id in (prims.PrimIDs.CAT, "torch.cat", "torch.stack"): safe_map_flat(put_grad, symbol.args, result) @@ -4065,7 +3590,8 @@ def vjp_call_metafunc(detached: bool, primals, cotangents, trace: Trace, **kwarg # TODO: Can't use a Symbol here because mixed executor sybsymbols seem to be -# unsupported. See https://github.com/Lightning-AI/lightning-thunder/issues/1308 +# unsupported. See issue "Could not find an executor for bound symbol when its subsymbols +# are not fully supported by a single executor" vjp_call = partial( vjp_call_metafunc, False ) # Symbol(id=Transforms.VjpOp, name="vjp_call", meta=partial(vjp_call_metafunc, False)) @@ -4190,7 +3716,7 @@ def unpacking_fn(saved_for_backward, cotangents): # NOTE: Returning namedtuples from compiled functions doesn't work. See: -# https://github.com/Lightning-AI/lightning-thunder/issues/881 +# "Allow returning namedtuples from compiled functions" # Note [Grad forward output spec] # If it did work it would be nice to use this namedtuple # instead of the plain tuple or dict that we're using now. @@ -4372,10 +3898,13 @@ def decorator(func): def maybe_downcast_to(dtype, args): allowed_downcast_types = (dtypes.float16, dtypes.bfloat16, dtypes.float32) - if all(tree_map(lambda a: a.dtype in allowed_downcast_types, args)): - return tree_map(lambda a: maybe_convert_to_dtype(a, dtype), args) - else: - return args + + def map_fn(a): + if isinstance(a, TensorProxy) and a.dtype in allowed_downcast_types: + return maybe_convert_to_dtype(a, dtype) + return a + + return tree_map(map_fn, args) @register_autocast_rule("torch.matmul") diff --git a/thunder/core/utils.py b/thunder/core/utils.py index 9dfb95e767..1edebe70e4 100644 --- a/thunder/core/utils.py +++ b/thunder/core/utils.py @@ -812,8 +812,8 @@ def safe_map(f, *args): def safe_map_flat(f, *args): def convert_sequences_to_tuple(x): - if isinstance(x, Sequence): - return tuple(x) + if not isinstance(x, str) and isinstance(x, Sequence) and not isinstance(x, Proxy): + return tuple(convert_sequences_to_tuple(y) for y in x) return x args_flat_spec = safe_map(lambda x: tree_flatten(convert_sequences_to_tuple(x)), args) diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index 5f5a67febd..3c8128ffeb 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -1,4 +1,3 @@ -import copy import inspect from inspect import Parameter, Signature from itertools import chain @@ -13,6 +12,9 @@ from thunder.core.transform_common import dce +_cache = {} + + def make_aug_forward_and_backward(bsym: BoundSymbol) -> tuple[Callable, Callable]: """ Given a bound symbol, return a pair of forward and backward functions @@ -33,13 +35,20 @@ def make_aug_forward_and_backward(bsym: BoundSymbol) -> tuple[Callable, Callable A pair of forward and backward functions. """ import thunder - from thunder.core.transforms import _grad_fn_map + from thunder.common import _make_cache_key + from thunder.core.transforms import _get_gradfn, eval_trace - joint_forward_backward = _grad_fn_map.get(bsym.sym.id, None) + joint_forward_backward = _get_gradfn(bsym) utils.check( joint_forward_backward is not None, lambda: f"Cannot generate forward and backward functions for {bsym.sym.name}", ) + + key = (bsym.sym, subkey := _make_cache_key(bsym.args, bsym.kwargs)) + cached_result = _cache.get(key, None) if subkey is not None else None + if cached_result is not None: + return cached_result + joint_trace = thunder.trace(inline_trace=False, use_dce=False)(joint_forward_backward, *bsym.args, **bsym.kwargs) consumers = utils.consumers(joint_trace) @@ -67,20 +76,31 @@ def find_backward_output(forward_input): bw_outputs_args = tree_map(find_backward_output, joint_trace.args) bw_outputs_kwargs = tree_map(find_backward_output, joint_trace.kwargs) meta_parameters = inspect.signature(bsym.sym.meta).parameters + meta_parameters = { + name: param + for name, param in meta_parameters.items() + if param.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.POSITIONAL_ONLY) + } bw_outputs = {name: bw_output for name, bw_output in utils.safe_zip(meta_parameters, bw_outputs_args)} bw_outputs = bw_outputs | bw_outputs_kwargs flat_bw_outputs, _ = tree_flatten(bw_outputs) - backward_bsyms = utils.find_producer_symbols(joint_trace, flat_bw_outputs, bw_inputs) - unpacking_ops = ( + backward_bsyms = utils.find_producer_symbols(joint_trace, flat_bw_outputs, tree_flatten(bw_inputs)[0]) + skip = ( prims.PrimIDs.UNPACK_EMPTY_DICT, prims.PrimIDs.UNPACK_KEY, prims.PrimIDs.UNPACK_SEQUENCE, prims.PrimIDs.UNPACK_TRIVIAL, + prims.PrimIDs.GET_GRAD, ) - backward_bsyms = [bsym for bsym in backward_bsyms if bsym.sym.id not in unpacking_ops] + backward_bsyms = [bsym for bsym in backward_bsyms if bsym.sym.id not in skip] backward_bsyms.append(prims.python_return.bind(bw_outputs, output=())) + forward_input_proxies = tree_flatten((joint_trace.args, joint_trace.kwargs))[0] + forward_input_proxies = [arg for arg in forward_input_proxies if isinstance(arg, Proxy)] + forward_bsyms = utils.find_producer_symbols(joint_trace, tree_flatten(joint_trace.output)[0], forward_input_proxies) + backward_bsyms = [bsym for bsym in backward_bsyms if bsym not in forward_bsyms] + # Find required info from forward trace for backward trace backward_producers = utils.producers(backward_bsyms) saved_for_backward = [] @@ -91,7 +111,37 @@ def find_backward_output(forward_input): if arg not in backward_producers and variableify(arg) not in map(variableify, tree_flatten(bw_inputs)[0]): saved_for_backward.append(arg) - backward_params = [Parameter(x.name, Parameter.POSITIONAL_OR_KEYWORD) for x in chain(saved_for_backward, bw_inputs)] + saved_for_backward = list({variableify(arg): arg for arg in saved_for_backward}.values()) + + # Augment forward trace to include saved_for_backward as output + augmented_forward_trace = from_trace(joint_trace) + augmented_forward_trace.bound_symbols = [ + b for b in joint_trace.bound_symbols if b.sym.id not in (PrimIDs.PUT_GRAD, PrimIDs.GET_GRAD) + ] + return_bsym = augmented_forward_trace.bound_symbols[-1] + assert return_bsym.sym.id == PrimIDs.RETURN + augmented_forward_trace.bound_symbols[-1] = prims.python_return.bind( + (joint_trace.output, saved_for_backward), output=() + ) + # Remove put/get grad and backward symbols from augmented forward trace + augmented_forward_trace = dce(augmented_forward_trace) + + # Check if any of the bound symbols in the backward trace are also in the + # augmented forward trace + # If so, remove them from the backward trace + same_bsyms = set(augmented_forward_trace.bound_symbols) & set(backward_bsyms) + if same_bsyms: + backward_bsyms = [bsym for bsym in backward_bsyms if bsym not in same_bsyms] + additional_saved = [o for bsym in same_bsyms for o in bsym.flat_proxy_outs] + saved_for_backward += list({variableify(arg): arg for arg in additional_saved}.values()) + augmented_forward_trace.bound_symbols[-1] = prims.python_return.bind( + (joint_trace.output, saved_for_backward), output=() + ) + + backward_params = [ + Parameter(getattr(x, "name", f"arg{i}"), Parameter.POSITIONAL_OR_KEYWORD) + for i, x in enumerate(chain(saved_for_backward, bw_inputs)) + ] backward_signature = Signature(backward_params) def backward_fn(): @@ -106,15 +156,15 @@ def backward_fn(): backward_trace.kwargs = {} backward_trace.bound_symbols = backward_bsyms - # Augment forward trace to include saved_for_backward as output - augmented_forward_trace = from_trace(joint_trace) - augmented_forward_trace.bound_symbols = copy.copy(joint_trace.bound_symbols) - return_bsym = augmented_forward_trace.bound_symbols[-1] - assert return_bsym.sym.id == PrimIDs.RETURN - augmented_forward_trace.bound_symbols[-1] = prims.python_return.bind( - (joint_trace.output, saved_for_backward), output=() - ) - # Remove put/get grad from augmented forward trace - augmented_forward_trace = dce(augmented_forward_trace) + # Creating new functions instead of using partial to avoid limitations in + # codeutils.get_siginfo + # https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/core/codeutils.py#L349-L353 + def fw_fn(*args, **kwargs): + return eval_trace(augmented_forward_trace, *args, **kwargs) + + def bw_fn(*args, **kwargs): + return eval_trace(backward_trace, *args, **kwargs) + + _cache[key] = fw_fn, bw_fn - return augmented_forward_trace.python_callable(), backward_trace.python_callable() + return fw_fn, bw_fn diff --git a/thunder/distributed/__init__.py b/thunder/distributed/__init__.py index ecb280149f..0a0eb943d9 100644 --- a/thunder/distributed/__init__.py +++ b/thunder/distributed/__init__.py @@ -68,9 +68,11 @@ def skip_data_parallel_grad_sync() -> None: def _sync_grads(module: torch.nn.Module) -> None: + import thunder + params_with_grad = [p for p in module.parameters() if p.grad is not None] grads = [p.grad for p in params_with_grad] - process_group = module._lc_cd.process_group_for_ddp + process_group = thunder.compile_data(module).process_group_for_ddp torch._foreach_div_(grads, process_group.size()) with tdist.distributed_c10d._coalescing_manager(group=process_group, async_ops=True) as cm: for g in grads: @@ -234,7 +236,7 @@ def main(): # Starts broadcasts # TODO Make these broadcast asyncs # TODO Perform up to two broadcasts at a time - # https://github.com/Lightning-AI/lightning-thunder/issues/727 + # See issue "Update ddp to use async broadcasts" # TODO "Bucket" small tensors together before broadcasting with torch.no_grad(): for param in model.parameters(): diff --git a/thunder/distributed/transforms/fsdp.py b/thunder/distributed/transforms/fsdp.py index c07afce08e..ad92f5e6db 100644 --- a/thunder/distributed/transforms/fsdp.py +++ b/thunder/distributed/transforms/fsdp.py @@ -198,7 +198,7 @@ def maybe_swap_proxies_of_bsym_and_update_swap_map(bsym: BoundSymbol) -> bool: lambda: f"{variableify(param)} not found in param set: {(variableify(p) for p in self.original_params)}", ) if param not in self.param_to_bucket: - # This path is hihly likely to be backward reduce-scatter bucketing: + # This path is highly likely to be backward reduce-scatter bucketing: # when a param does not require grad, a trace could still have reduce-scatter # and wait in its trace while the grad in the return statement is already # replaced with `None`. diff --git a/thunder/distributed/utils.py b/thunder/distributed/utils.py index f34f98f3ad..9d9fa7bf42 100644 --- a/thunder/distributed/utils.py +++ b/thunder/distributed/utils.py @@ -52,7 +52,6 @@ def key(node: Node) -> int: # TODO: Currently prefer the most memory-efficient way for ZeRO3, -# https://github.com/Lightning-AI/lightning-thunder/issues/1925 # Need a strategy to balance the efficiency # and memory usage in the future def sort_waits_for_zero3(execution_trace): diff --git a/thunder/examine/__init__.py b/thunder/examine/__init__.py index 23661586bb..c75bc5f3b7 100644 --- a/thunder/examine/__init__.py +++ b/thunder/examine/__init__.py @@ -43,7 +43,7 @@ def __exit__(self, exc_type, exc_value, traceback): # TODO Maybe have this print additional information and return more metadata? -# TODO Accept kwargs for compile (like langctx) +# TODO Accept kwargs for jit (like langctx) # TODO Add profiling (or profiling option) to determine if we have a slowdown # TODO If an error occurs, try to minify the program to produce a smaller sample to reproduce the error def examine(fn: Callable, *args, show_call_stack: bool | int = False, **kwargs): @@ -141,7 +141,7 @@ def examine(fn: Callable, *args, show_call_stack: bool | int = False, **kwargs): return - # Step 3 Attempts to compile the function using lightning.compile + # Step 3 Attempts to compile the function using thunder.jit try: cfn = thunder.jit(fn) except Exception as e: @@ -151,7 +151,7 @@ def examine(fn: Callable, *args, show_call_stack: bool | int = False, **kwargs): ) raise e - # Step 4 Attemps to execute the function using lightning.compile + # Step 4 Attempt to execute the function using thunder.jit lc_result: Any try: lc_result = cfn(*args, **kwargs) diff --git a/thunder/examine/memory_caculation.py b/thunder/examine/memory_caculation.py index 45cc74a9a0..149349748c 100644 --- a/thunder/examine/memory_caculation.py +++ b/thunder/examine/memory_caculation.py @@ -22,7 +22,7 @@ "torch_wait_prim_impl", ) -# A whitelist registry of symbols that require special memory calculation; +# A registry of symbols that require special memory calculation; # if not registered, the default memory calculation function is used. memory_calculate_impls: dict[Symbol, Callable] = dict() diff --git a/thunder/executors/apex_entropyex.py b/thunder/executors/apex_entropyex.py index 7d5eec13fb..818199ad5b 100644 --- a/thunder/executors/apex_entropyex.py +++ b/thunder/executors/apex_entropyex.py @@ -10,11 +10,7 @@ from thunder.core.proxies import TensorProxy from thunder.core.symbol import Symbol from thunder.core.utils import check, same_shape -from thunder.core.transforms import get_grad, put_grad, put_grads, mean_backward, sum_backward -from thunder.core.transforms import ( - register_augmented_forward_with_checker, - register_backward, -) +from thunder.core.transforms import get_grad, put_grad, put_grads, mean_backward, restore_reduced_dims from thunder.extend import OperatorExecutor, register_executor @@ -49,7 +45,8 @@ def apex_available() -> bool: # TODO Consider performing the reduction as part of a traceable epilogue -# See https://github.com/Lightning-AI/lightning-thunder/issues/1357 +# See "Update the apex cross entropy executor to put its reduction in a +# traceable epilogue" # NOTE Apex's cross entropy doesn't accept ignore_index >= 0, or the weight, size_average, or reduce parameters def _apex_cross_entropy_impl( a: torch.Tensor, @@ -196,78 +193,6 @@ def _cross_entropy_checker( return True -# Check out -# https://github.com/Lightning-AI/lightning-thunder/blob/main/dev_tutorials/thunder-add-vjp-rule.md -# for a tutorial on how to add a VJP rule for any Symbol. We use our new -# primitives to register a VJP rule for torch.nn.functional.cross_entropy. This -# function is registered as the augmented forward rule for -# torch.nn.functional.cross_entropy below -def apex_cross_entropy_forward_rule( - a, - target, - weight=None, - size_average=None, - ignore_index=-100, - reduce=None, - reduction="mean", - label_smoothing=0.0, -): - loss, max_log_sum_exp = apex_xentropy( - a, - target=target, - reduction=reduction, - label_smoothing=label_smoothing, - ) - primal = loss - saved_for_backward = (a, target, max_log_sum_exp, reduction, label_smoothing) - return primal, saved_for_backward - - -register_augmented_forward_with_checker( - apex_ex, - ltorch.cross_entropy.id, - _cross_entropy_checker, - apex_cross_entropy_forward_rule, -) - - -# This function is the backward rule for torch.nn.functional.cross_entropy. It -# accepts the primal output and saved_for_backward from the forward pass and -# returns the backward output. The backward output is a tuple of the backward -# output for each differentiable Tensor input to the forward pass. In this case, -# the forward pass has 1 such input, so the backward output is a single Tensor. -# This function is registered as the backward rule for -# torch.nn.functional.cross_entropy -@register_backward((apex_ex, ltorch.cross_entropy.id)) -def apex_cross_entropy_backward_rule( - logits, - labels, - max_log_sum_exp, - reduction, - smoothing, - grad, -): - from thunder.core.transforms import mean_backward, sum_backward - - if reduction == "mean": - grad = mean_backward(max_log_sum_exp.ndim, max_log_sum_exp.shape, (0,), grad) - elif reduction == "sum": - grad = sum_backward(max_log_sum_exp.shape, (0,), grad) - elif reduction == "none": - pass - else: - raise ValueError(f"Invalid reduction: {reduction}") - - grad_logits = apex_xentropy_bwd( - grad, - logits, - target=labels, - max_log_sum_exp=max_log_sum_exp, - label_smoothing=smoothing, - ) - return grad_logits - - # Translate calls from torch.nn.functional.cross_entropy to apex_xentropy (when the checker above returns True) def _cross_entropy_transform( a: TensorProxy, @@ -304,7 +229,7 @@ def _apex_cross_entropy_grad( if reduction == "mean": g = mean_backward(max_log_sum_exp.ndim, max_log_sum_exp.shape, (0,), g) elif reduction == "sum": - g, _ = sum_backward(max_log_sum_exp.shape, (0,), g) + g = restore_reduced_dims(g, (0,), max_log_sum_exp.shape) # NOTE Apex's xentropy bwd requires the grad computation to be performed in fp32 a_ = a.contiguous() diff --git a/thunder/executors/cudnn_layernormex.py b/thunder/executors/cudnn_layernormex.py index 1a84811278..b6f260cdd7 100644 --- a/thunder/executors/cudnn_layernormex.py +++ b/thunder/executors/cudnn_layernormex.py @@ -19,9 +19,7 @@ def cudnn_available() -> bool: return CUDNN_AVAILABLE -# WARNING: cudnn executor is experimental. Tests that use cudnn might fail.\n -# Issue for tracking support: https://github.com/Lightning-AI/lightning-thunder/issues/880~ - +# WARNING: cudnn layernorm executor is experimental. Tests that use cudnn might fail. from dataclasses import dataclass from functools import lru_cache from typing import Union, Dict diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py index 20759bb085..75494cff5a 100644 --- a/thunder/executors/cudnnex.py +++ b/thunder/executors/cudnnex.py @@ -21,9 +21,6 @@ def cudnn_available() -> bool: return CUDNN_AVAILABLE -# WARNING: cudnn executor is experimental. Tests that use cudnn might fail.\n -# Issue for tracking support: https://github.com/Lightning-AI/lightning-thunder/issues/880~ - from dataclasses import dataclass from functools import lru_cache from typing import Union, Dict @@ -38,8 +35,6 @@ def cudnn_available() -> bool: get_grad, put_grad, put_grads, - register_augmented_forward_with_checker, - register_backward, ) from thunder.extend import OperatorExecutor, register_executor import thunder.torch as ltorch @@ -341,7 +336,11 @@ def _cudnn_sdpa_forward_checker( if d % 8 != 0 or d > 128: return False - return True + is_backward_supported = _cudnn_sdpa_backward_checker( + query, key, value, attn_mask, dropout_p, is_causal, scale=scale + ) + + return True and is_backward_supported @langctx("torch") @@ -604,99 +603,6 @@ def cudnn_sdpa_bwd_impl( ) -@langctx("torch") -def cudnn_sdpa_aug_fw_rule_checker( - query: TensorProxy, - key: TensorProxy, - value: TensorProxy, - attn_mask: None | TensorProxy, - dropout_p: float, - is_causal: bool, - *, - scale: None | float, -) -> bool: - from thunder.core.compile_data import get_compile_data - - cd = get_compile_data() - if cudnn_ex in cd.executors_list: - is_forward_supported = _cudnn_sdpa_forward_checker( - query, key, value, attn_mask, dropout_p, is_causal, scale=scale - ) - is_backward_supported = _cudnn_sdpa_backward_checker( - query, key, value, attn_mask, dropout_p, is_causal, scale=scale - ) - return is_forward_supported and is_backward_supported - return False - - -def cudnn_sdpa_aug_fw_rule( - query, - key, - value, - attn_mask=None, - dropout_p: float = 0.0, - is_causal: bool = False, - *, - scale: float | None = None, -): - output, softmax_stats, seed, offset = cudnn_sdpa_fwd( - query, key, value, attn_mask, dropout_p, is_causal, scale=scale - ) - saved_for_backward = ( - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - scale, - output, - softmax_stats, - seed, - offset, - ) - return output, saved_for_backward - - -register_augmented_forward_with_checker( - cudnn_ex, - "torch.nn.functional.scaled_dot_product_attention", - cudnn_sdpa_aug_fw_rule_checker, - cudnn_sdpa_aug_fw_rule, -) - - -@register_backward((cudnn_ex, "torch.nn.functional.scaled_dot_product_attention")) -def cudnn_sdpa_backward_rule( - query: Proxy, - key: Proxy, - value: Proxy, - attn_mask: None | Proxy, - dropout_p: float, - is_causal: bool, - scale: None | float, - out: Proxy, - softmax_stats: Proxy, - seed: Proxy, - offset: Proxy, - grad_out: Proxy, -): - return cudnn_sdpa_bwd( - grad_out, - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - out, - softmax_stats, - seed, - offset, - scale=scale, - ) - - @langctx("torch") def _cudnn_sdpa_transform( query: TensorProxy, @@ -729,7 +635,7 @@ def _cudnn_sdpa_grad( ) g = get_grad(primal) - grad_query, grad_key, grad_val, grad_attn_mask = cudnn_sdpa_bwd( + grads = cudnn_sdpa_bwd( g, query, key, @@ -743,6 +649,11 @@ def _cudnn_sdpa_grad( offset, scale=scale, ) + if attn_mask is None: + grad_query, grad_key, grad_val = grads + else: + grad_query, grad_key, grad_val, grad_attn_mask = grads + put_grads((query, key, value), (grad_query, grad_key, grad_val)) if attn_mask is not None: put_grad(attn_mask, grad_attn_mask) diff --git a/thunder/executors/nvfuserex.py b/thunder/executors/nvfuserex.py index 3e25890dde..27d4ce721b 100644 --- a/thunder/executors/nvfuserex.py +++ b/thunder/executors/nvfuserex.py @@ -33,7 +33,6 @@ def required_nvfuser_version() -> LooseVersion: return LooseVersion("0.0.1") -# NOTE We require nvFuser version 0.0.1 or greater def nvfuser_available() -> bool: v = nvfuser_version() return v is not None and v >= required_nvfuser_version() diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 14fba83ffe..7a5ebfe5ae 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -206,7 +206,7 @@ def create_fd( # NOTE nvFuser's default max length is 1024 operations at the time of this writing # This arbitrarily increases it to 9999 # TODO Review splititng very large fusions or removing the max length restriction completely - # See https://github.com/Lightning-AI/lightning-thunder/issues/901 + # See "Very large nvFuser fusions hit max_length" fd = FusionDefinition(max_length=9999) with fd: # NOTE Adding constants is disabled for the moment in favor of definining them inline @@ -524,8 +524,7 @@ class nvFuserExecutor(FusionExecutor): def __init__(self): super().__init__("nvfuser", version=nvfuser.version()) - # TODO: Replace this with a query to current CompileData after - # https://github.com/Lightning-AI/lightning-thunder/pull/1517 is merged + # TODO: Replace this with a query to a compile option self._use_rematerialization = True fuel_str = os.getenv("NVFUSER_OPTIMIZATION_FUEL") @@ -762,8 +761,6 @@ def fusion_pass(self, trace: TraceCtx) -> TraceCtx: # TODO has_cuda_input_or_output is too restrictive a check on what should be fused # TODO check whether a function would output a CPU tensor? -- can nvFuser fuse such operations? # ex. device_put to a CPU device from a CUDA device - # (mruberry) I don't know if nvFuser even attempts to fuse any operation that can go - # cross-device today def _should_fuse(a: Node, b: Node): def _can_fuse_node(n: Node): # if already merged, then node can be fused @@ -842,8 +839,7 @@ def _can_fuse_node(n: Node): # Some of the operations might be better placed with its consumers (for # example residual connection in transformer block). This pass moves - # them to the consumer. See - # https://github.com/Lightning-AI/lightning-thunder/issues/1520 + # them to the consumer. if self._use_rematerialization: fusedtrace = rematerialize(fusedtrace) @@ -1150,7 +1146,7 @@ def _pad_check(a: TensorProxy, padding_value: Number, padding_config: tuple[int, # nvFuser's pad op requires pad_widths to be a sequence of Python numbers # (lo_n, hi_n, lo_{n-1}, hi_{n-1}, ...) where dimensions are counted in reverse # as shown, and dilation is not supported. -# This is in constrant to lightning.compile's pad primitive, which specifies padding +# This is in constrast to thunder.jit's pad primitive, which specifies padding # and dilation as an ndim-length list of (lo, hi, dilation) triples. # NOTE padding_value must be an nvConstant (or nvScalar?) def pad( @@ -1256,7 +1252,8 @@ def squeeze(a: TensorProxy, /, dims: Sequence[int], *, fd: FusionDefinition, lc_ # register_supported(PrimIDs.TAKE, take, _take_check) # TAKE_ALONG_AXIS is currently disabled -# See https://github.com/NVIDIA/Fuser/issues/458 +# There was an nvFuser bug that prevented this which is now fixed; we should +# investigate re-enabling take_along_axis. # # TODO Check that the nvFuser version is >= 0.0.10 when this operator was added # def take_along_axis(a: TensorProxy, /, index: TensorProxy, dim: int, *, fd: FusionDefinition, lc_to_nv_map: dict) -> Any: # nv_a = getnv(a, fd, lc_to_nv_map) @@ -1709,12 +1706,6 @@ def div(a: TensorProxy | Number, b: TensorProxy | Number, *, fd: FusionDefinitio nva = getnv(a, fd, lc_to_nv_map) nvb = getnv(b, fd, lc_to_nv_map) - # TODO nvFuser sometimes generates an innacurate result when dividing by a number - # Remove this workaround once the issue is fixed - # See: https://github.com/NVIDIA/Fuser/issues/160 - if isinstance(b, Number): - return fd.ops.mul(nva, fd.ops.reciprocal(nvb)) - # NOTE It's currently significantly faster for nvFuser to multiply the reciprocal than divide # return fd.ops.div(nva, nvb) return fd.ops.mul(nva, fd.ops.reciprocal(nvb)) diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index ab6b35abba..8f1604e718 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -57,12 +57,13 @@ def preserve_bsym(bsym: BoundSymbol) -> Any: # If the executor has an execution transform, it's called and True is returned # If no executor can execute the BoundSymbol, False is returned def visit_helper_(bsym: BoundSymbol) -> None | bool: - if bsym.sym.executor is not None or bsym.sym.python_impl is not None: + if bsym.sym.python_impl is not None: return None ex: Executor for ex in executors_list: # TODO Consider allowing operator executors to claim portions of operations + # TODO Should FusionExecutors be allowed to claim bsym with bsym.sym.executor? if (isinstance(ex, OperatorExecutor) and ex.can_execute(bsym)) or ( isinstance(ex, FusionExecutor) and ex.can_fuse(bsym) ): @@ -87,6 +88,9 @@ def visit_helper_(bsym: BoundSymbol) -> None | bool: safe_map_flat(update_swapmap, bsym.output, out) return True + if bsym.sym.executor is not None: + return None + return False def visit_(bsym: BoundSymbol) -> transforms.VISIT_TYPE: @@ -158,8 +162,8 @@ def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor]) return extrace -# NOTE: See more details for motivation in the following issue: -# https://github.com/Lightning-AI/lightning-thunder/issues/515 +# This is needed to ensure that subsymbol changes are reflected in the Python +# code generator. def _update_fusion_call_ctx(bsym: BoundSymbol) -> BoundSymbol: """Update the call_ctx information of the fusion BoundSymbol object. diff --git a/thunder/executors/sdpaex.py b/thunder/executors/sdpaex.py index 8c9aa4657c..a2ff22007b 100644 --- a/thunder/executors/sdpaex.py +++ b/thunder/executors/sdpaex.py @@ -17,8 +17,6 @@ get_grad, put_grad, put_grads, - register_augmented_forward_with_checker, - register_backward, ) from thunder.extend import OperatorExecutor, register_executor @@ -50,8 +48,8 @@ def ceil_div(a: int, b: int) -> int: def _sdpa_pad_head_dimension(a: torch.Tensor) -> torch.Tensor: head_size = a.shape[-1] - # NOTE short-circuit path when we already have compatible head_size - # See https://github.com/Lightning-AI/lightning-thunder/issues/1505 + # If the head is already a multiple of 8, then we don't need to pad. The + # pad op can be quite expensive in some cases. if head_size % 8 == 0: return a padding_size = ceil_div(head_size, 8) * 8 - head_size @@ -59,8 +57,7 @@ def _sdpa_pad_head_dimension(a: torch.Tensor) -> torch.Tensor: def _sdpa_slice_head_dimension(a: torch.Tensor, head_size: int) -> torch.Tensor: - # NOTE short-circuit path when we already have compatible head_size - # See https://github.com/Lightning-AI/lightning-thunder/issues/1505 + # ditto pad_head_dimension: the slice can be expensive, so skip if possible. if head_size % 8 == 0: return a return a[:, :, :, 0:head_size] @@ -491,8 +488,8 @@ def _scaled_dot_product_attention_fused( *, scale: None | float = None, ): - # NOTE Select fused sdpa using PyTorch eager mode selection behavior - # See https://github.com/Lightning-AI/lightning-thunder/issues/622 + # Figure out which SDPA to use. There are performance cliffs to the various + # implementations, and this makes the decision cognizant of those cliffs. backend = _fused_sdp_choice(query, key, value, attn_mask, dropout_p, is_causal, scale) utils.check( @@ -530,8 +527,8 @@ def _scaled_dot_product_attention_grad( *, scale: None | float = None, ): - # NOTE Select fused sdpa using PyTorch eager mode selection behavior - # See https://github.com/Lightning-AI/lightning-thunder/issues/622 + # Figure out which SDPA to use. There are performance cliffs to the various + # implementations, and this makes the decision cognizant of those cliffs. backend = _fused_sdp_choice(query, key, value, attn_mask, dropout_p, is_causal, scale) utils.check( @@ -640,8 +637,9 @@ def _fused_sdp_choice( is_causal = is_causal.value if LooseVersion(torch.__version__) < LooseVersion("2.2.0"): - # NOTE Select fused sdpa using PyTorch eager mode selection behavior - # See https://github.com/Lightning-AI/lightning-thunder/issues/622 + # Figure out which SDPA to use. There are performance cliffs to the + # various implementations, and this makes the decision cognizant of + # those cliffs. backend = torch._fused_sdp_choice( fake_query, fake_key, @@ -707,112 +705,3 @@ def _scaled_dot_product_attention_checker( execution_transform=_scaled_dot_product_attention_fused, grad_transform=_scaled_dot_product_attention_grad, ) - - -def scaled_dot_product_attention_aug_fw( - query: TensorProxy, - key: TensorProxy, - value: TensorProxy, - attn_mask: TensorProxy | None, - dropout_p: float = 0.0, - is_causal: bool = False, - *, - scale: float | None = None, -): - # NOTE Select fused sdpa using PyTorch eager mode selection behavior - # See https://github.com/Lightning-AI/lightning-thunder/issues/622 - backend = _fused_sdp_choice(query, key, value, attn_mask, dropout_p, is_causal, scale) - - utils.check( - backend != SpdaBackend.ERROR, - lambda: "Unable to find valid backend for scaled_dot_product_attention.", - ) - utils.check( - backend != SpdaBackend.MATH, - lambda: "The fallback to sdpa thunder reference is not implemented.", - exception_type=NotImplementedError, - ) - - tensor_args = (query, key, value) - scalar_args = (dropout_p, is_causal) - input_args = (*tensor_args, attn_mask, *scalar_args, scale) - if backend == SpdaBackend.FLASH_ATTENTION: - # Use flash attention kernel - (primal, *remaining_results, debug_attn_mask) = sdpfa_gradfwd(*tensor_args, *scalar_args, scale=scale) - # NOTE Remaining results contains [logsumexp, *flash_attn_only_residuals, *philox_residuals] - residuals = (*input_args, primal, *remaining_results) - return primal, residuals - elif backend == SpdaBackend.MEMORY_EFFICIENT: - # Use memory efficient kernel, which supports fp32 and attention mask arguments - (primal, logsumexp, *philox_residuals) = sdpea_gradfwd(*tensor_args, attn_mask, *scalar_args, scale=scale) - flash_attn_only_residuals = (None,) * 4 - residuals = (*input_args, primal, logsumexp, *flash_attn_only_residuals, *philox_residuals) - return primal, residuals - - -register_augmented_forward_with_checker( - sdpa_ex, - "torch.nn.functional.scaled_dot_product_attention", - _scaled_dot_product_attention_checker, - scaled_dot_product_attention_aug_fw, -) - - -@register_backward((sdpa_ex, "torch.nn.functional.scaled_dot_product_attention")) -def scaled_dot_product_attention_backward( - query: Proxy, - key: Proxy, - value: Proxy, - attn_mask: None | Proxy, - dropout_p: float, - is_causal: bool, - scale: None | float, - out: Proxy, - logsumexp: Proxy, - cum_seq_q: None | Proxy, - cum_seq_k: None | Proxy, - max_q: None | int, - max_k: None | int, - philox_seed: Proxy, - philox_offset: Proxy, - grad_out: Proxy, -): - tensor_args = (query, key, value) - scalar_args = (dropout_p, is_causal) - flash_attention_args = (cum_seq_q, cum_seq_k, max_q, max_k) - philox_args = (philox_seed, philox_offset) - use_flash_attn = all(map(lambda a: a is not None, (cum_seq_q, cum_seq_k, max_q, max_k))) - if use_flash_attn: - ( - grad_query, - grad_key, - grad_val, - ) = sdpfa_bwd( - grad_out, - *tensor_args, - out, - logsumexp, - *flash_attention_args, - *scalar_args, - *philox_args, - scale=scale, - ) - # grad_attn_mask is None since it is not supported by flash_attention kernel - return grad_query, grad_key, grad_val - else: - ( - grad_query, - grad_key, - grad_val, - grad_attn_mask, - ) = sdpea_bwd( - grad_out, - *tensor_args, - attn_mask, - out, - logsumexp, - *philox_args, - *scalar_args, - scale=scale, - ) - return grad_query, grad_key, grad_val, grad_attn_mask diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 2eb3f1f56b..93c0c2d874 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -161,7 +161,7 @@ def wrapper(*args, **kwargs): return decorator -def split_forward_backward(func, compile_data, compile_stats, /, *args, **kwargs): +def split_forward_backward(computation_trc, compile_data, compile_stats, /, *args, **kwargs): from thunder import trace from thunder.executors.passes import transform_for_execution from thunder.executors.passes import del_last_used @@ -170,6 +170,17 @@ def split_forward_backward(func, compile_data, compile_stats, /, *args, **kwargs from thunder.cudagraphs import CUDAGraphExecutor from thunder.distributed.utils import sort_waits, sort_data_parallel_syncs, sort_waits_for_zero3 from thunder.distributed.transforms import FSDPCommBucketing + from thunder.core.transforms import eval_trace + + # TODO: the trace->func->trace could likely be simplified (and look nicer) + # we cannot use python_callable() here, see the old repos 2458 + if not isinstance(computation_trc, TraceCtx): + # for the legacy codepath + func = computation_trc + else: + + def func(*args): + return eval_trace(computation_trc, *args) utils.check(compile_data is not None, lambda: "`compile_data` is required") @@ -178,6 +189,7 @@ def make_trace(func): trace(compile_data=compile_data, inline_trace=False, insert_ddp_syncs=not compile_data.using_jit), func ) + computation_trc.kwargs = {} # NOTE: This function is rather slow, so it's intended to be used # behind a cache. ba = signature(func).bind(*args, **kwargs) @@ -193,6 +205,9 @@ def make_trace(func): primal_trace = make_trace(func)(*args, **kwargs) primal_trace = sort_data_parallel_syncs(primal_trace) + if compile_stats is not None: + compile_stats.last_traces.append(primal_trace) + # torch.autograd.Function doesn't support non-flat outputs, the # grads wouldn't be propagated and backward receives None for each # non-flat non-tensor output. The output must also be a flat tuple, @@ -317,8 +332,7 @@ def make_trace(func): bw_traces.append(bw_extrace) if compile_stats is not None: - compile_stats.primal_trace = primal_trace - compile_stats.forward_last_traces = fw_traces - compile_stats.backward_last_traces = bw_traces + compile_stats.last_traces += fw_traces + compile_stats.last_backward_traces += bw_traces return fw_extrace, bw_extrace diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index 9d82e12526..a48635e878 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -72,7 +72,8 @@ def torch_interpreted_func(*args): # _transform_for_operator_executor_execution implementation that need to be # fixed first. One issue is that it doesn't maintain the ssa form of the # trace, which is needed for all the passes to work correctly. - # TODO: https://github.com/Lightning-AI/lightning-thunder/issues/1767 + # TODO: issue "Try using _transform_for_operator_executor_execution for + # torch.compile executor" torch_trace = trace(inline_trace=False)(torch_interpreted_func, *sorted_unique_inputs) compiled_func = torch.compile(torch_trace.python_callable()) @@ -84,13 +85,13 @@ def compiled_func_wrapper(*args): orig = getattr(torch._dynamo.eval_frame.guarded_backend_cache, "skip_backend_check_for_run_only_mode", None) try: # TODO: Remove this hack - # This is a hack to get around the fact that for some reason Dynamo - # doesn't recreate a guard for the compiled function called from the - # backward thread. This is a problem because the guard is created - # with the forward thread id, and the guard is not valid for the - # backward thread. I couldn't come up with a small repro to file an - # issue to PyTorch. - # https://github.com/pytorch/pytorch/issues/114674 + # Dynamo doesn't recreate a guard for the compiled function called + # from the backward thread. This is a problem because the guard is + # created with the forward thread ID, and the guard is not valid + # for the backward thread. + # Issue filed: https://github.com/pytorch/pytorch/issues/114674 + # We should be able to remove this hack once we're sure that the + # above fix has propagated to all supported PyTorch releases. torch._dynamo.eval_frame.guarded_backend_cache.skip_backend_check_for_run_only_mode = True return compiled_func(*args) finally: diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index 85cfd9c2ce..535fddd943 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -633,7 +633,7 @@ def _stride_order_prim_impl(a: torch.Tensor, order: Sequence[int]) -> torch.Tens rsqrt = _register_torch_operation("rsqrt") # # NOTE That PyTorch's "sgn" corresponds with the "sign" primitive sgn = _register_torch_operation("sgn", like=ltorch.sign) -# # NOTE torch.sign isn't bound here because lightning.compile always uses sgn +# # NOTE torch.sign isn't bound here because thunder always uses sgn # sign = _register_torch_operation("sign") signbit = _register_torch_operation("signbit") sin = _register_torch_operation("sin") @@ -1095,8 +1095,8 @@ def _index_put_prim_transform( return index_put(a, indices, values, accumulate) -# NOTE torch.compile currently fails to compile scatter add in bfloat16 -# TODO RC1 Separate this into a torch.compile executor +# NOTE torch.compile has a compilation issue with scatter add in bfloat16, +# hence the special case here. # NOTE The scatter add transforms must set the torch language context explicitly so the .to() method # on tensors is resolved (alternatively they could explicitly call thunder.torch.to) @langctx(Languages.TORCH) @@ -1109,7 +1109,8 @@ def _scatter_add_prim_transform(a: TensorProxy, /, index: TensorProxy, value: Te return scatter_add(a, dim, index, value) -# NOTE torch.compile currently fails to compile scatter add in bfloat16 +# NOTE torch.compile has a compilation issue with scatter add in bfloat16, +# hence the special case here. @langctx(Languages.TORCH) def _scatter_add_transform(a: TensorLike, /, dim: int, index: TensorLike, src: TensorLike) -> TensorLike: # NOTE scatter_add does not participate in type promotion, so if a has the bfloat16 dtype, then so does src @@ -1243,7 +1244,7 @@ def _cross_entropy_backward_impl( ) # TODO Add support nll_loss_nd, weight tensor, and label_smoothing options. - # See https://github.com/Lightning-AI/lightning-thunder/issues/704 + # See issue "Add support for remaining cross_entropy_loss arguments." utils.check(a.ndim <= 2 and target.ndim <= 1, lambda: f"multi-dimension cross-entropy is not supported.") utils.check(weight is None, lambda: f"weight tensor argument is not supported.") @@ -1601,7 +1602,7 @@ def _unpack_prim_impl( ) -> list[torch.Tensor]: return torch._utils._unflatten_dense_tensors(buffer, tensors) - # TODO(crcrpar): Make this compatible with the coming torch_compile executor as it's doing really well for cat and reshape. + # TODO(crcrpar): Make this compatible with the torch.compile executor as it's doing really well for cat and reshape. # NOTE(crcrpar): why no caching/resue of buffer? # This prim is only used by fsdp backward for now. # Bucketing of reduce-scatter, i.e., creating a buffer for @@ -1621,7 +1622,6 @@ def _unpack_prim_impl( # To support individual copies from gradient to its bucket requires a mask or an arrayy of indices to achieve correct behavior. # In PyTorch, the op for this is [`Tensor.index_copy_`](https://pytorch.org/docs/stable/generated/torch.Tensor.index_copy_.html) where even the index tensor needs to be on the same device as ``self`` and ``tensor``. # So caching of the bucketing for fsdp backward would bloat up the memory consumption, which is the main reason this doesn't do any caching. - # See https://github.com/Lightning-AI/lightning-thunder/pull/1669/commits/a942b87e88738ce94f874c21d4adc38749ff10d7#diff-c2fd275781ba0c4aa7eec811bebb7bf0b6ca52a236b510ce7dfbb831d4d9bb40R197-R233 for the potential implementation's clumisiness. # # example of two unsharded gradients of [4, 2] and [4], world size of 4: # -------- ------ diff --git a/thunder/executors/transformer_engineex.py b/thunder/executors/transformer_engineex.py index 91d32e7a88..ced5d8fdb1 100644 --- a/thunder/executors/transformer_engineex.py +++ b/thunder/executors/transformer_engineex.py @@ -19,10 +19,6 @@ import thunder.core.prims as prims from thunder.core.proxies import TensorProxy, CollectionProxy from thunder.core.symbol import Symbol -from thunder.core.transforms import ( - register_augmented_forward_with_checker, - register_backward, -) from thunder.extend import OperatorExecutor, register_executor __all__ = [ @@ -411,15 +407,6 @@ def linear_forward_rule_checker(a: TensorProxy, w: TensorProxy, bias: None | Ten return False -register_augmented_forward_with_checker( - transformer_engine_ex, - prims.linear.id, - linear_forward_rule_checker, - linear_forwad_rule, -) - - -@register_backward((transformer_engine_ex, prims.linear.id)) def linear_backward_rule(a_shape, w_shape, b_shape, ctx_idx, grad): return te_functional_linear_backward(grad, a_shape, w_shape, b_shape, ctx_idx) @@ -429,9 +416,21 @@ def _linear_transform(a: TensorProxy, w: TensorProxy, b: TensorProxy) -> torch.T return _create_fp8_linear_bound_symbol(a, w, b, is_grad_enabled=False) +def _linear_grad(a: TensorProxy, w: TensorProxy, b: TensorProxy) -> TensorProxy: + out, saved_for_backward = linear_forwad_rule(a, w, b) + g = prims.get_grad(out) + ga, gw, gb = linear_backward_rule(*saved_for_backward, g) + prims.put_grad(a, ga) + prims.put_grad(w, gw) + if b is not None: + prims.put_grad(b, gb) + return out + + # Registers the implementation for torch.nn.functional.linear transformer_engine_ex.register_implementation( prims.linear, checker=_linear_checker, execution_transform=_linear_transform, + grad_transform=_linear_grad, ) diff --git a/thunder/extend/__init__.py b/thunder/extend/__init__.py index 2d2ce1fd2a..2acdcf6427 100644 --- a/thunder/extend/__init__.py +++ b/thunder/extend/__init__.py @@ -25,6 +25,7 @@ "add_always_executor", "remove_default_executor", "remove_always_executor", + "register_lookaside", ] @@ -208,6 +209,7 @@ def register_operator( module: None | type | ModuleType = None, fn: None | Callable = None, bind_postprocess: None | Callable = None, + replaces: None | Callable = None, python_printer: Callable = default_python_printer, ) -> Symbol: assert (like is None) ^ (meta is None), "Expected one and only one of 'like' and 'meta' to be specified" @@ -237,6 +239,9 @@ def _bind_postprocess(bsym: BoundSymbol) -> None: ) self.opmap[name] = sym + if replaces is not None: + register_lookaside(replaces, sym) + return sym def register_implementation( @@ -381,3 +386,10 @@ def deregister_executor(ex: Hashable | Executor) -> None: remove_always_executor(id) remove_default_executor(id) + + +def register_lookaside(function, symbol) -> None: + """register `symbol` as a lookaside for `function`""" + import thunder.core.jit_ext + + thunder.core.jit_ext._general_jit_lookaside_map[function] = thunder.core.jit_ext.interpreter_needs_wrap(symbol) diff --git a/thunder/numpy/__init__.py b/thunder/numpy/__init__.py index b53e9f26f5..ed2e6474aa 100644 --- a/thunder/numpy/__init__.py +++ b/thunder/numpy/__init__.py @@ -11,9 +11,6 @@ from thunder.core.symbol import Symbol import thunder.clang as clang -# TODO RC1 Remove this -from thunder.core.script.noinline import noinline - # # NumPy operator definitions @@ -28,7 +25,7 @@ def __init__(self, *, method_name: None | str = None): def __call__(self, fn: Callable) -> Symbol: _fn = langctx(Languages.NUMPY)(fn) - _fn = noinline(_fn) + # TODO: register _fn as opaque with the interpreter or do this in jit_ext? sym = Symbol(name=fn.__name__, meta=_fn) if self.method_name is not None: diff --git a/thunder/numpy/langctx.py b/thunder/numpy/langctx.py index 0f6e445d8f..75961b6efb 100644 --- a/thunder/numpy/langctx.py +++ b/thunder/numpy/langctx.py @@ -23,7 +23,7 @@ def has_method(self, id: str) -> bool: return id in _method_name_to_fn_map def get_method(self, id: str, *args, **kwargs) -> Callable: - # Note: concrete implmenetations should only raise AttributeError or + # Note: concrete implementations should only raise AttributeError or # return None for "missing" methods as the proxies will # route __getattr__ to here and hasattr relies on __getattr__ # throwing AttributeError (only) when the attribute does diff --git a/thunder/core/script/__init__.py b/thunder/py.typed similarity index 100% rename from thunder/core/script/__init__.py rename to thunder/py.typed diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py index 4ae624b6db..c321a01394 100644 --- a/thunder/tests/distributed/test_ddp.py +++ b/thunder/tests/distributed/test_ddp.py @@ -134,7 +134,7 @@ def _run(cls, rank, test_name, file_name, pipe): "DDP test requires CUDA and NCCL `torch.distributed` backend", ) class CompileDDPTest(DataParallelTestCase): - # Ref: https://github.com/Lightning-AI/lightning-thunder/issues/646 + # Reference issue "Add an example of DDP(compile(model)) to tests" def test_ddp_compile_module(self): model = ToyModel().to(self.rank) ddp_model = DDP(thunder.jit(model, device_ids=[self.rank])) @@ -157,7 +157,7 @@ def test_ddp_compile_module(self): last_loss = loss.detach().item() assert init_loss > last_loss - # Ref: https://github.com/Lightning-AI/lightning-thunder/issues/599 + # Reference issue "[tracker] Support DistributedDataParallel" def test_compile_ddp_module(self): model = ToyModel().to(self.rank) with self.assertRaisesRegex( @@ -452,7 +452,7 @@ def test_ddp_grad_bucketing(self, executor, bucket_size_in_mb: int): x = torch.ones((2, 12)).to(device) cm(x).mean().backward() - bwd_extrace = thunder.last_traces(cm)[1][-1] + bwd_extrace = thunder.last_backward_traces(cm)[-1] bsym_sym_id_list = [bsym.sym.id for bsym in bwd_extrace.bound_symbols] pack_syms = tuple(filter(lambda a: a == pack_prim_impl.id, bsym_sym_id_list)) unpack_syms = tuple(filter(lambda a: a == unpack_prim_impl.id, bsym_sym_id_list)) @@ -476,20 +476,24 @@ def test_rematerialize_all_gather(self): m = ToyModel().to(device) cm = thunder.jit( fsdp(m, device=device, broadcast_from=0), - interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON, ) x = torch.ones((2, 12), device=device) cm(x).mean().backward() - fwd_trc = thunder.last_traces(cm)[0][0] - bwd_trc = thunder.last_traces(cm)[1][0] + (fwd_trc,) = ( + t for t in thunder.last_traces(cm) if getattr(t.get_provenance(), "pss", "") == "Augmented forward pass" + ) + bwd_trc = thunder.last_backward_traces(cm)[0] from thunder.core.rematerialization import rematerialize_all_gather result_fwd_trc, result_bwd_trc = rematerialize_all_gather(fwd_trc, bwd_trc) # check the return statement in forward trace is updated - sharded_param_names = ("t_net1_weight", "t_net2_weight") - unshard_param_names = ("t5", "t16") + # TODO: this is not stable w.r.t. details of the processing, the sharded correspond to ("t_net1_weight", "t_net2_weight") + # in the original trace and are inputs to all_gather, the unshard are the outputs fo the corresponding wait + # If you fix this to be dynamically discerned, you'll be my hero. + sharded_param_names = ("t3", "t4") + unshard_param_names = ("t10", "t21") result_saved_for_bwd = [x.name for x in fwd_trc.bound_symbols[-1].args[1][0]] self.assertTrue(all(t not in sharded_param_names for t in result_saved_for_bwd)) self.assertTrue(all(t in result_saved_for_bwd for t in unshard_param_names)) @@ -648,7 +652,7 @@ def test_ddp_grad_parity_with_without_bucketing(self, executor): else: self.assertEqual(tuple(p.grad for p in cm.parameters() if p.grad is not None), gradients) - # TODO(crcrpar): Add torch compile to executors_list once it's available. + # TODO(crcrpar): Add torch compile to executors_list @common_utils.parametrize( "executor,bucketing_strategy,fsdptype", product( @@ -826,8 +830,8 @@ def check_inflight_allgather_number(trc, n: int, is_bucket: bool): loss.backward() # get the trace before sorting - fwd_trc = thunder.last_traces(cm)[0][-2] - bwd_trc = thunder.last_traces(cm)[1][-2] + fwd_trc = thunder.last_traces(cm)[-2] + bwd_trc = thunder.last_backward_traces(cm)[-2] from thunder.distributed.utils import limit_in_flight_allgathers @@ -1101,7 +1105,7 @@ def _test_native_ddp_helper(input_data): tdist.destroy_process_group(pg) if rank == 0: - bwd_extrace_sym_ids = [bsym.sym.id for bsym in thunder.last_traces(cmodel)[1][-1].bound_symbols] + bwd_extrace_sym_ids = [bsym.sym.id for bsym in thunder.last_backward_traces(cmodel)[-1].bound_symbols] pack_unpack_update_bucket_view_found = ( "torch_pack_prim_impl" in bwd_extrace_sym_ids and "torch_unpack_prim_impl" in bwd_extrace_sym_ids diff --git a/thunder/tests/framework.py b/thunder/tests/framework.py index ce237d9c97..dfdbeb8312 100644 --- a/thunder/tests/framework.py +++ b/thunder/tests/framework.py @@ -133,25 +133,11 @@ def executors_list(self) -> list[extend.Executor]: @singledispatchmethod def make_callable_legacy(self, fn, **kwargs): - # TODO: an error is thrown for many functions because __code__ and - # inspect.signature for wrapped functions is not matching. - # KeyError: 'args' - # thunder/core/script/frontend.py:125: KeyError - # with disable_preprocessing=False - # See: https://github.com/Lightning-AI/lightning-thunder/issues/386 - disable_preprocessing = kwargs.pop("disable_preprocessing", True) - return thunder.compile( - fn, executors_list=self.executors_list(), disable_preprocessing=disable_preprocessing, **kwargs - ) + assert kwargs.pop("disable_preprocessing", True) + return thunder.compile(fn, executors_list=self.executors_list(), disable_preprocessing=True, **kwargs) @singledispatchmethod def make_callable(self, fn, **kwargs): - # TODO: an error is thrown for many functions because __code__ and - # inspect.signature for wrapped functions is not matching. - # KeyError: 'args' - # thunder/core/script/frontend.py:125: KeyError - # with disable_preprocessing=False - # See: https://github.com/Lightning-AI/lightning-thunder/issues/386 return thunder.jit(fn, executors=self.executors_list(), **kwargs) @make_callable.register diff --git a/thunder/tests/lit_gpt_model.py b/thunder/tests/lit_gpt_model.py index 32889c6cc5..57a85089bc 100644 --- a/thunder/tests/lit_gpt_model.py +++ b/thunder/tests/lit_gpt_model.py @@ -139,7 +139,7 @@ def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> k = self.k.index_copy_(2, input_pos, k) v = self.v.index_copy_(2, input_pos, v) return k, v - # THUNDER unsupported: https://github.com/Lightning-AI/lightning-thunder/issues/1145 + # See issue: "Support more indexing operators (index_copy and index_add)" k = self.k = torch.index_add(self.k, 2, input_pos, k) v = self.v = torch.index_add(self.v, 2, input_pos, v) # THUNDER bug: cannot return self.k, self.v here (may be cuda graphs related - no minimum repro) diff --git a/thunder/tests/llama2_model.py b/thunder/tests/llama2_model.py index bf70e56531..b8277cd757 100644 --- a/thunder/tests/llama2_model.py +++ b/thunder/tests/llama2_model.py @@ -64,7 +64,6 @@ def apply_rotary_emb( xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1) # reshape freqs_cos and freqs_sin for broadcasting - # https://github.com/Lightning-AI/lightning-thunder/issues/1106 a, b = freqs_cos.shape freqs_cos = freqs_cos.view(1, a, 1, b) freqs_sin = freqs_sin.view(1, a, 1, b) @@ -247,7 +246,8 @@ def forward(self, tokens: torch.Tensor, targets: torch.Tensor | None = None) -> if targets is not None: # if we are given some desired targets also calculate the loss logits = self.output(h) - # https://github.com/Lightning-AI/lightning-thunder/issues/1108 + # Workaround for issue "Unexpected KeyError when self attribute is + # set inside forward" # self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) else: # inference-time mini-optimization: only forward the output on the very last position diff --git a/thunder/tests/nanogpt_model.py b/thunder/tests/nanogpt_model.py index 128ac8c0f4..9d81b3abb8 100644 --- a/thunder/tests/nanogpt_model.py +++ b/thunder/tests/nanogpt_model.py @@ -69,8 +69,8 @@ def __init__(self, config): # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") # NOTE: The original Karpathy's script hides bias registration behind a flag - # but we don't do that here. We always register bias, because of preprocessing bug: - # https://github.com/Lightning-AI/lightning-thunder/issues/605 + # but we don't do that here. We always register bias due to a now-fixed + # bug in thunder. # TODO: Move the bias registration to be happening `if not self.flash` once the bug is fixed. # if not self.flash: # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") @@ -173,8 +173,8 @@ def __init__(self, config): self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # with weight tying when using torch.compile() some warnings get generated: # "UserWarning: functional_call was passed multiple values for tied weights. - # This behavior is deprecated and will be an error in future versions" - # not 100% sure what this is, so far seems to be harmless. TODO investigate + # This behavior is deprecated and will be an error in future versions". + # So far this seems to be harmless. TODO investigate self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying # init all weights @@ -236,7 +236,6 @@ def forward(self, idx, targets=None): # NOTE: Advanced indexing is not yet supported in Thunder # RuntimeError: Advanced indexing currently only supports tensors as sequence elements # inference-time mini-optimization: only forward the lm_head on the very last position - # See https://github.com/Lightning-AI/lightning-thunder/issues/894 # logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim logits = self.lm_head(x) loss = None diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index 7659d57fe2..17af749d37 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -770,8 +770,8 @@ def _abs_torch(x: torch.Tensor | Number): # digamma is defined for all complex numbers EXCEPT negative integers and zero digamma_opinfo = OpInfo( clang.digamma, - # NOTE: Restrict domain to avoid singularities because of - # https://github.com/Lightning-AI/lightning-thunder/issues/1138 + # NOTE: Restrict domain to avoid singularities because of issue + # "OpInfos do not use singularity_fn to produce "more stable" samples." domain=(eps, math.inf), # NOTE: digamma returns NaN for all negative integers. It returns -Inf when x = 0. singularity_fn=lambda x: torch.where(x > 0, x, (x - torch.round(x))), @@ -1079,7 +1079,7 @@ def _abs_torch(x: torch.Tensor | Number): dtypes=(datatypes.float16, datatypes.complex32), devicetypes=(devices.DeviceType.CPU,), ), - # see https://github.com/csarofeen/pytorch/issues/2367 + # Used to be an nvFuser bug here; TODO explore removing this xfail DecorateInfo( pytest.mark.xfail, "test_core_vs_torch_consistency", @@ -1179,8 +1179,7 @@ def _abs_torch(x: torch.Tensor | Number): sample_input_generator=elementwise_unary_generator, torch_reference=_elementwise_unary_torch(torch.sgn), test_directives=( - # TODO Need to add nvfuser specific support for complex sign - # https://github.com/csarofeen/pytorch/issues/2492 + # TODO nvFuser needs support for complex sign DecorateInfo( pytest.mark.xfail, dtypes=(datatypes.complexfloating,), @@ -1284,7 +1283,9 @@ def _abs_torch(x: torch.Tensor | Number): sample_input_generator=elementwise_unary_generator, torch_reference=_elementwise_unary_torch(torch.tan), test_directives=( - # See https://github.com/csarofeen/pytorch/issues/2360 + # TODO investigate nvFuser's implementation here; for complex datatypes + # nvFuser's tanh might be inaccurate, causing numerical mismatches, but + # also this concern is potentially stale in 03/2024. DecorateInfo( pytest.mark.xfail, "test_core_vs_torch_consistency", executors=("nvfuser",), dtypes=(datatypes.complex64,) ), @@ -1305,7 +1306,9 @@ def _abs_torch(x: torch.Tensor | Number): sample_input_generator=elementwise_unary_generator, torch_reference=_elementwise_unary_torch(torch.tanh), test_directives=( - # See https://github.com/csarofeen/pytorch/issues/2360 + # TODO investigate nvFuser's implementation here; for complex datatypes + # nvFuser's tanh might be inaccurate, causing numerical mismatches, but + # also this concern is potentially stale in 03/2024. DecorateInfo( pytest.mark.xfail, "test_core_vs_torch_consistency", executors=("nvfuser",), dtypes=(datatypes.complex64,) ), @@ -1358,7 +1361,9 @@ def _abs_torch(x: torch.Tensor | Number): sample_input_generator=partial(elementwise_unary_generator, exclude_zero=True), torch_reference=_elementwise_unary_torch(torch.log), test_directives=( - # See https://github.com/csarofeen/pytorch/issues/2360 + # TODO investigate nvFuser's implementation here; for complex datatypes + # nvFuser's tanh might be inaccurate, causing numerical mismatches, but + # also this concern is potentially stale in 03/2024. DecorateInfo( pytest.mark.xfail, "test_core_vs_torch_consistency", executors=("nvfuser",), dtypes=(datatypes.complex64,) ), @@ -1379,7 +1384,9 @@ def _abs_torch(x: torch.Tensor | Number): sample_input_generator=partial(elementwise_unary_generator, exclude_zero=True), torch_reference=_elementwise_unary_torch(torch.log10), test_directives=( - # See https://github.com/csarofeen/pytorch/issues/2360 + # TODO investigate nvFuser's implementation here; for complex datatypes + # nvFuser's tanh might be inaccurate, causing numerical mismatches, but + # also this concern is potentially stale in 03/2024. DecorateInfo( pytest.mark.xfail, "test_core_vs_torch_consistency", executors=("nvfuser",), dtypes=(datatypes.complex64,) ), @@ -1407,14 +1414,16 @@ def _abs_torch(x: torch.Tensor | Number): sample_input_generator=elementwise_unary_generator, torch_reference=_elementwise_unary_torch(torch.log1p), test_directives=( - # See https://github.com/csarofeen/pytorch/issues/2360 + # TODO investigate nvFuser's implementation here; for complex datatypes + # nvFuser's tanh might be inaccurate, causing numerical mismatches, but + # also this concern is potentially stale in 03/2024. DecorateInfo( pytest.mark.xfail, "test_core_vs_torch_consistency", executors=("nvfuser",), dtypes=(datatypes.complexfloating,), ), - # NOTE: Torch gives wrong result: https://github.com/pytorch/pytorch/issues/94333 + # NOTE: Torch has an issue: https://github.com/pytorch/pytorch/issues/94333 DecorateInfo( pytest.mark.skip, "test_core_vs_torch_consistency", @@ -1452,7 +1461,9 @@ def _abs_torch(x: torch.Tensor | Number): sample_input_generator=partial(elementwise_unary_generator, exclude_zero=True), torch_reference=_elementwise_unary_torch(torch.log2), test_directives=( - # See https://github.com/csarofeen/pytorch/issues/2360 + # TODO investigate nvFuser's implementation here; for complex datatypes + # nvFuser's tanh might be inaccurate, causing numerical mismatches, but + # also this concern is potentially stale in 03/2024. DecorateInfo( pytest.mark.xfail, "test_core_vs_torch_consistency", executors=("nvfuser",), dtypes=(datatypes.complex64,) ), @@ -1556,7 +1567,7 @@ def relu6_error_generator(op, device, dtype=torch.float32, **kwargs): dtypes=(datatypes.float16,), devicetypes=(devices.DeviceType.CPU,), ), - # TODO: https://github.com/Lightning-AI/lightning-thunder/issues/1444 + # TODO: we might have a tolerance issue here with relu6. DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -1577,7 +1588,7 @@ def relu6_error_generator(op, device, dtype=torch.float32, **kwargs): "test_core_vs_torch_consistency", dtypes=(datatypes.bool8,), ), - # TODO: https://github.com/Lightning-AI/lightning-thunder/issues/1444 + # TODO: we might have a tolerance issue here with relu6. DecorateInfo( pytest.mark.xfail(strict=False), "test_vjp_correctness", @@ -1613,7 +1624,7 @@ def selu_error_generator(op, device, dtype=torch.float32, **kwargs): datatypes.bfloat16, ), ), - # TODO: https://github.com/Lightning-AI/lightning-thunder/issues/1444 + # TODO: we might have a tolerance issue here with relu6. DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -1674,8 +1685,9 @@ def selu_error_generator(op, device, dtype=torch.float32, **kwargs): devicetypes=(devices.DeviceType.CPU,), active_if=LooseVersion(torch.__version__) < "1.13", ), - # TODO: nvfuser needs to return copy for integer dtypes. - # https://github.com/csarofeen/pytorch/issues/2499 + # TODO: nvFuser does not define an integer trunc() and thus compilation + # fails. They should probably map integer trunc() to an identity op. + # Until they do, this test won't work for integer types. DecorateInfo( pytest.mark.xfail, "test_core_vs_torch_consistency", @@ -1742,7 +1754,8 @@ def elementwise_binary_generator(op, device, dtype, requires_grad, *, no_rhs_num sample_input_generator=elementwise_binary_generator, torch_reference=torch.add, test_directives=( - # See https://github.com/csarofeen/pytorch/issues/2549 + # See issue "broadcast_in_dim: The size of contiguity must equal to the + # number of non-broadcasting IterDomains" DecorateInfo( pytest.mark.skip, "test_jvp_correctness", @@ -1810,7 +1823,8 @@ def elementwise_binary_generator(op, device, dtype, requires_grad, *, no_rhs_num sample_input_generator=elementwise_binary_generator, torch_reference=torch.copysign, test_directives=( - # See https://github.com/Lightning-AI/lightning-thunder/issues/2218 + # See issue: "flaky test: + # test_vjp_correctness_copysign_torch_cuda_float64 is flaky" DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -1827,8 +1841,7 @@ def elementwise_binary_generator(op, device, dtype, requires_grad, *, no_rhs_num sample_input_generator=elementwise_comparison_generator, torch_reference=torch.eq, test_directives=( - # There's a problem of reducing a tensor produced by full op - # See https://github.com/NVIDIA/Fuser/issues/132 + # TODO: enable this; there was a now-fixed nvFuser bug causing issues. DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -1926,8 +1939,7 @@ def fmod_sample_input_generator(op, device, dtype, requires_grad, **kwargs): sample_input_generator=elementwise_comparison_generator, torch_reference=torch.ge, test_directives=( - # There's a problem of reducing a tensor produced by full op - # See https://github.com/NVIDIA/Fuser/issues/132 + # TODO: enable this; there was a now-fixed nvFuser bug causing issues. DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -1969,8 +1981,7 @@ def fmod_sample_input_generator(op, device, dtype, requires_grad, **kwargs): sample_input_generator=elementwise_comparison_generator, torch_reference=torch.le, test_directives=( - # There's a problem of reducing a tensor produced by full op - # See https://github.com/NVIDIA/Fuser/issues/132 + # TODO: enable this; there was a now-fixed nvFuser bug causing issues. DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -1987,8 +1998,7 @@ def fmod_sample_input_generator(op, device, dtype, requires_grad, **kwargs): sample_input_generator=elementwise_comparison_generator, torch_reference=torch.lt, test_directives=( - # There's a problem of reducing a tensor produced by full op - # See https://github.com/NVIDIA/Fuser/issues/132 + # TODO: enable this; there was a now-fixed nvFuser bug causing issues. DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -2018,7 +2028,8 @@ def fmod_sample_input_generator(op, device, dtype, requires_grad, **kwargs): sample_input_generator=elementwise_binary_generator, torch_reference=torch.mul, test_directives=( - # See https://github.com/csarofeen/pytorch/issues/2549 + # See issue "broadcast_in_dim: The size of contiguity must equal to the + # number of non-broadcasting IterDomains" DecorateInfo( pytest.mark.skip, "test_jvp_correctness", @@ -2040,8 +2051,7 @@ def fmod_sample_input_generator(op, device, dtype, requires_grad, **kwargs): sample_input_generator=elementwise_comparison_generator, torch_reference=torch.ne, test_directives=( - # There's a problem of reducing a tensor produced by full op - # See https://github.com/NVIDIA/Fuser/issues/132 + # TODO: enable this; there was a now-fixed nvFuser bug causing issues. DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -2066,15 +2076,15 @@ def fmod_sample_input_generator(op, device, dtype, requires_grad, **kwargs): pytest.mark.skip, dtypes=(datatypes.float16, datatypes.bfloat16), ), - # See https://github.com/Lightning-AI/lightning-thunder/issues/972 - # PyTorch's nextafter may be causing CUDA illegal memory accesses + # TODO There was an issue with nextafter in PyTorch that should now be + # resolved; re-enable this and test. DecorateInfo( pytest.mark.skip, "test_core_vs_torch_consistency", devicetypes=(devices.DeviceType.CUDA,), ), - # See https://github.com/Lightning-AI/lightning-thunder/issues/972 - # PyTorch's nextafter may be causing CUDA illegal memory accesses + # TODO There was an issue with nextafter in PyTorch that should now be + # resolved; re-enable this and test. DecorateInfo( pytest.mark.skip, executors=("torch",), @@ -2101,8 +2111,8 @@ def polygamma_sample_input_generator(op, device, dtype, requires_grad, *, no_rhs polygamma_opinfo = OpInfo( ltorch.polygamma, - # NOTE: Restrict domain to avoid singularities because of - # https://github.com/Lightning-AI/lightning-thunder/issues/1138 + # NOTE: Restrict domain to avoid singularities. See issue "OpInfos do not + # use singularity_fn to produce "more stable" samples" # NOTE: polygamma returns NaN, -Inf, or Inf for all negative integers. domain=(eps, math.inf), sample_input_generator=polygamma_sample_input_generator, @@ -2155,7 +2165,7 @@ def pow_sample_input_generator(op, device, dtype, requires_grad, *, no_rhs_numbe "test_core_vs_torch_consistency", dtypes=(datatypes.complex32,), ), - # See https://github.com/csarofeen/pytorch/issues/2361 + # TODO For complex numbers we have some numerical consistency issues. DecorateInfo( pytest.mark.xfail, "test_core_vs_torch_consistency", @@ -2241,7 +2251,8 @@ def pow_sample_input_generator(op, device, dtype, requires_grad, *, no_rhs_numbe ), # torch doesn't support bool true_divide DecorateInfo(pytest.mark.xfail, "test_core_vs_torch_consistency", dtypes=(datatypes.bool8,)), - # See https://github.com/csarofeen/pytorch/issues/2549 + # See issue "broadcast_in_dim: The size of contiguity must equal to the + # number of non-broadcasting IterDomains" DecorateInfo( pytest.mark.skip, "test_vjp_correctness", @@ -2384,7 +2395,8 @@ def addcmul_addcdiv_sample_generator(op, device, dtype, requires_grad, **kwargs) "test_core_vs_torch_consistency", dtypes=(datatypes.exact,), ), - # This test is flaky, see https://github.com/Lightning-AI/lightning-thunder/issues/2244 + # See issue "flaky test: + # test_vjp_correctness_addcdiv_nvfuser_cuda_float64" DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -2510,8 +2522,7 @@ def clamp_sample_generator(op, device, dtype, requires_grad, **kwargs): torch_reference=torch.clamp, dtypes=(datatypes.signedinteger, datatypes.unsignedinteger, datatypes.floating), test_directives=( - # This test is flaky - # See https://github.com/Lightning-AI/lightning-thunder/issues/1992 + # see issue "test_vjp_correctness_clamp_nvfuser_cuda_float64 is flaky" DecorateInfo( pytest.mark.skip, "test_vjp_correctness", @@ -2715,7 +2726,8 @@ def broadcast_in_dim_error_generator(op, device, **kwargs): pytest.mark.xfail, "test_errors", ), - # See https://github.com/csarofeen/pytorch/issues/2549 + # See issue "broadcast_in_dim: The size of contiguity must equal to the number of + # non-broadcasting IterDomains" DecorateInfo( pytest.mark.skip, "test_jvp_correctness", @@ -2831,6 +2843,11 @@ def diagonal_sample_generator(op, device, dtype, requires_grad, **kwargs): ltorch.diagonal, sample_input_generator=diagonal_sample_generator, torch_reference=torch.diagonal, + test_directives=( + # thunder.torch.diagonal meta function is not correctly implemented for + # input case ((1, 2, 0, 3), -1, 0, -1) + DecorateInfo(pytest.mark.xfail(strict=True), "test_vjp_correctness"), + ), ) shape_ops.append(diagonal_opinfo) @@ -3214,8 +3231,7 @@ def pad_sample_generator(op, device, dtype, requires_grad, **kwargs): # Versions of above examples but with padding between elements set to 0 ((2, 2), ((1, 1, 0), (-1, 2, 0))), ((2, 0, 3), ((1, 0, 0), (1, 1, 0), (0, 0, 0))), - # See https://github.com/Lightning-AI/lightning-thunder/issues/415 - # The PyTorch lowering does not handle this case properly + # See issue "PyTorch pad prim lowering handles out-of-bands negative padding incorrectly" # ((7, 5), ((0, 0, 0), (-6, 2, 0))), ((5, 7), ((0, 0, 0), (-6, 2, 0))), ((3, 2, 5), ((-2, 1, 0), (1, -1, 0), (-1, 3, 0))), # negative pad in all 3 dims @@ -3242,12 +3258,12 @@ def _jax_pad(a, padding_value, padding_config): executors=("torch",), dtypes=(datatypes.complexfloating,), ), - # See issue https://github.com/Lightning-AI/lightning-thunder/issues/2053 + # See issue "pad+nvFuser: wrong results when applied to 1-numel inputs" DecorateInfo( pytest.mark.xfail, executors=("nvfuser",), ), - # See issue https://github.com/Lightning-AI/lightning-thunder/issues/2053 + # See issue "pad+nvFuser: wrong results when applied to 1-numel inputs" DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -3318,7 +3334,7 @@ def pad_torch_error_generator(op, device, dtype=torch.float32, **kwargs): # TODO: only remove these cases when the executor is nvfuser -# FIXME: Zero-dim cases are skipped due to https://github.com/csarofeen/pytorch/issues/2383 +# TODO: zero-dim cases had a bug, now fixed; re-enable. # FIXME: tensors with no elements are skipped because of no nvfuser support def reshape_sample_generator(op, device, dtype, requires_grad, **kwargs): make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -3438,8 +3454,8 @@ def slice_in_dim_sample_generator(op, device, dtype, requires_grad, **kwargs): sample_input_generator=slice_in_dim_sample_generator, jax_reference=jax.lax.slice_in_dim if JAX_AVAILABLE else None, test_directives=( - # nvfuser executor doesn't support pad correctly - # See https://github.com/Lightning-AI/lightning-thunder/issues/285 + # TODO: nvfuser executor didn't support pad correctly, but now it should. + # Test and re-enable. DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -3450,8 +3466,8 @@ def slice_in_dim_sample_generator(op, device, dtype, requires_grad, **kwargs): shape_ops.append(slice_in_dim) -# TODO https://github.com/Lightning-AI/lightning-thunder/issues/416 -# Add strides and slicing outside tensor boundaries +# See issue "Slice prim samples need strides and slicing beyond tensor +# boundaries" def slice_prim_sample_generator(op, device, dtype, requires_grad, **kwargs): make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -3473,8 +3489,8 @@ def slice_prim_sample_generator(op, device, dtype, requires_grad, **kwargs): sample_input_generator=slice_prim_sample_generator, jax_reference=jax.lax.slice if JAX_AVAILABLE else None, test_directives=( - # nvfuser executor doesn't support pad correctly - # See https://github.com/Lightning-AI/lightning-thunder/issues/285 + # TODO: nvfuser executor didn't support pad correctly, but now it should. + # Test and re-enable. DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -3535,7 +3551,7 @@ def split_sample_generator(op, device, dtype, requires_grad, **kwargs): ((4, 6, 7), 3, -1), ((4, 6, 7), 9, 1), ((4, 6, 7), (1, 2, 1, 2), 1), - # TODO https://github.com/Lightning-AI/lightning-thunder/issues/420 + # See issue "nvFuser split test failure" # ((4, 6, 7), (3, 1, 2, 0, 0, 1), -1), ((4, 4, 12), 4, 2), ) @@ -3760,8 +3776,8 @@ def tensor_split_sample_generator(op, device, dtype, requires_grad, **kwargs): sample_input_generator=tensor_split_sample_generator, torch_reference=torch.tensor_split, test_directives=( - # nvfuser executor doesn't support pad correctly - # See https://github.com/Lightning-AI/lightning-thunder/issues/285 + # TODO: nvfuser executor didn't support pad correctly, but now it should. + # Test and re-enable. DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -3888,6 +3904,72 @@ def torch_permute_reference(a, *dims): shape_ops.append(permute_opinfo) +def t_sample_generator(op, device, dtype, requires_grad, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # shape + cases = ( + (), + (1), + (4), + (4, 5), + ) + + for shape in cases: + yield SampleInput(make(shape)) + + +def t_error_generator(op, device, dtype=torch.float32, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype) + + # shape, error type, error message + cases = ( + ((4, 5, 6), RuntimeError, r"t\(\) expects a tensor with <= 2 dimensions, but self is 3D"), + ( + (4, 5, 6, 7), + RuntimeError, + r"t\(\) expects a tensor with <= 2 dimensions, but self is 4D", + ), + ) + + for shape, err_type, err_msg in cases: + yield SampleInput(make(shape)), err_type, err_msg + + +t_opinfo = OpInfo( + ltorch.t, + sample_input_generator=t_sample_generator, + error_input_generator=t_error_generator, + torch_reference=lambda x: torch.Tensor.t(x), +) +shape_ops.append(t_opinfo) + + +def reverse_dims_T_sample_generator(op, device, dtype, requires_grad, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # shape + cases = ( + (), + (1), + (4), + (4, 5), + (4, 5, 6), + (4, 5, 6, 7), + ) + + for shape in cases: + yield SampleInput(make(shape)) + + +reverse_dims_T_opinfo = OpInfo( + ltorch.reverse_dims_T, + sample_input_generator=reverse_dims_T_sample_generator, + torch_reference=lambda x: x.T, +) +shape_ops.append(reverse_dims_T_opinfo) + + def matrix_transpose_sample_generator(op, device, dtype, requires_grad, **kwargs): make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -4183,7 +4265,8 @@ def unsqueeze_sample_generator(op, device, dtype, requires_grad, **kwargs): sample_input_generator=unsqueeze_sample_generator, jax_reference=jax.lax.expand_dims if JAX_AVAILABLE else None, test_directives=( - # See https://github.com/csarofeen/pytorch/issues/2549 + # See issue "broadcast_in_dim: The size of contiguity must equal to the + # number of non-broadcasting IterDomains" DecorateInfo( pytest.mark.skip, "test_jvp_correctness", @@ -4398,7 +4481,8 @@ def _replace_random_percentage(a: torch.Tensor, value: Number, percentage: float dtypes=(datatypes.complex32,), devicetypes=(devices.DeviceType.CPU,), ), - # See https://github.com/csarofeen/pytorch/issues/2369 + # nvFuser had issues with complex reductions, now fixed; TODO re-enable + # this test. DecorateInfo( pytest.mark.xfail, dtypes=(datatypes.complexfloating,), @@ -4448,7 +4532,8 @@ def var_sample_generator(op, device, dtype, requires_grad): dtypes=(datatypes.complex32,), devicetypes=(devices.DeviceType.CPU, devices.DeviceType.CUDA), ), - # See https://github.com/csarofeen/pytorch/issues/2369 + # nvFuser had issues with complex reductions, now fixed; TODO re-enable + # this test. DecorateInfo( pytest.mark.xfail, dtypes=(datatypes.complexfloating,), @@ -4498,8 +4583,7 @@ def var_sample_generator(op, device, dtype, requires_grad): # Complex var is not supported yet dtypes=(datatypes.floating,), test_directives=( - # TODO FIXME nvFuser fails to compile var_mean for these tests - # See https://github.com/Lightning-AI/lightning-thunder/issues/1438 + # See issue "nvFuser fails to compile some var_mean tests" DecorateInfo( pytest.mark.xfail, "test_core_vs_torch_consistency", @@ -5145,7 +5229,8 @@ def einsum_error_generator(op, device, **kwargs): supports_grad=True, # TODO: test all integer types and figure out their dtype. dtypes=(datatypes.float32, datatypes.float64), - # See https://github.com/Lightning-AI/lightning-thunder/issues/1643. + # See issue "Disabled einsum tests might hide potential issues in our + # testing/op implementations" # Testing only float32, float64 now. # types=(datatypes.int64, datatypes.floating), # domain=(-1, +1), @@ -6023,7 +6108,7 @@ def group_norm_error_generator(op, device, **kwargs): dtypes=(datatypes.float16, datatypes.bfloat16), devicetypes=(devices.DeviceType.CUDA,), ), - # See https://github.com/Lightning-AI/lightning-thunder/issues/1405 + # This should be fixed now; TODO re-enable, test DecorateInfo( pytest.mark.xfail, executors=("nvfuser",), @@ -6411,7 +6496,8 @@ def embedding_sample_generator(op, device, dtype, requires_grad, **kwargs): dtypes=(datatypes.floating, datatypes.complexfloating), test_directives=( # TODO Investigate these discrepancies -- some dtype x executor configurations seem to be fine - # See https://github.com/Lightning-AI/lightning-thunder/issues/1387 + # See issue "phantom grad's embedding computation is divergent from + # PyTorch's" DecorateInfo( custom_comparator(partial(assert_close, atol=1, rtol=2)), "test_phantom_grad_vs_torch_consistency", @@ -6749,9 +6835,8 @@ def cross_entropy_reference_generator(op, device, dtype, requires_grad, **kwargs # TODO Enable cross entropy bwd weight support -# see https://github.com/Lightning-AI/lightning-thunder/issues/834 # TODO Enable test cases after adding support nll_loss_nd, weight tensor, and label_smoothing options. -# See https://github.com/Lightning-AI/lightning-thunder/issues/704 +# TODO see issue "Add support for remaining cross_entropy_loss arguments" def cross_entropy_sample_generator(op, device, dtype, requires_grad, **kwargs): make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -7098,7 +7183,7 @@ def interpolate_error_generator(op, device, dtype=torch.float32, **kwargs): dtypes=(datatypes.float16,), devicetypes=(devices.DeviceType.CPU,), ), - # https://github.com/Lightning-AI/lightning-thunder/issues/1032 + # This should be fixed now; TODO re-enable and test DecorateInfo( pytest.mark.xfail, "test_vjp_correctness", @@ -7115,7 +7200,8 @@ def interpolate_error_generator(op, device, dtype=torch.float32, **kwargs): prob_distr_ops = [] -# multinomial testing is currently disabled due to https://github.com/Lightning-AI/lightning-thunder/issues/2258 +# multinomial testing is currently disabled due to issue "randomness: enable +# PyTorch generators for operations like multinomial" # def multinomial_sample_generator(op, device, dtype, requires_grad, **kwargs): # make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index f7dfbd1267..54b27f4964 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -532,7 +532,7 @@ def foo(tup): @instantiate(dtypes=NOTHING) def test_type_promotion_tensors(executor, device, _): if executor == TorchExecutor: - pytest.xfail("https://github.com/Lightning-AI/lightning-thunder/issues/406") + pytest.xfail('see issue "vmap of sum doesn\'t work when dims are passed as a keyword argument"') def foo(a, b): return a + b @@ -590,7 +590,7 @@ def bar(a, b, c): @instantiate(dtypes=NOTHING) def test_type_promotion_numbers_and_tensors(executor, device, _): if executor == TorchExecutor: - pytest.xfail("https://github.com/Lightning-AI/lightning-thunder/issues/406") + pytest.xfail('See issue "Type promotion with the torchexecutor and elementwise operations is incorrect"') def foo(a, b, c): return a + b + c @@ -1130,7 +1130,8 @@ def test_detached_trace(executor, device: str, _): def test_normalized_args_prims_sum(executor, device: str, dtype: dtypes.dtype): # This test verifies that the recorded trace for a call to prims.sum # has its positional and keyword arguments normalized to the same form. - # See: https://github.com/Lightning-AI/lightning-thunder/issues/195 + # See issue "vmap of sum doesn't work when dims are passed as a keyword + # argument" a = make_tensor((2, 2), device=device, dtype=ltorch.to_torch_dtype(dtype)) def func_dim_posarg(x): @@ -1221,7 +1222,8 @@ def foo(x): assert str(trace).count("Testing") == 1 -# Check for https://github.com/Lightning-AI/lightning-thunder/issues/471 +# Check to verify the issue in "KeyError thrown in thunder.executor.utils.Region +# when None is passed in as input". @instantiate(dtypes=(thunder.float32,)) def test_argument_of_none(executor, device, dtype): from thunder.executors.utils import Region @@ -1683,7 +1685,7 @@ def test_transforms_vmap_axis_size(executor, device, _): @instantiate( dtypes=NOTHING, - decorators=(pytest.mark.xfail(reason="https://github.com/Lightning-AI/lightning-thunder/issues/2118"),), + decorators=(pytest.mark.xfail(reason='issue "flaky test: test_transforms_vjp_{2_1, 1_2}_nvfuser_cuda_None"'),), ) def test_transforms_vjp_1_2(executor, device, _): from thunder.core.transforms import vjp @@ -1790,7 +1792,7 @@ def func(x): @instantiate( dtypes=NOTHING, - decorators=(pytest.mark.xfail(reason="https://github.com/Lightning-AI/lightning-thunder/issues/2118"),), + decorators=(pytest.mark.xfail(reason='issue "flaky test: test_transforms_vjp_{2_1, 1_2}_nvfuser_cuda_None"'),), ) def test_transforms_vjp_2_1(executor, device, _): from thunder.core.transforms import vjp @@ -1830,7 +1832,8 @@ def func_2_1(x, y): # executors=( # nvFuserExecutor, # # TODO: Enable Torch executor once the issue with sum is fixed -# # See: https://github.com/Lightning-AI/lightning-thunder/issues/438 +# # See issue "Different behavior of sum(tensor, ()) for nvFuser and +# # Torch executor" # ), # ) # def test_transforms_vmap_inline_value_and_grad(executor, device, _): @@ -1950,8 +1953,8 @@ def f(a): assert "thunder.computation" in excinfo.traceback[-1].path -# TODO Add nvFuser support (https://github.com/Lightning-AI/lightning-thunder/issues/809) -# TODO Make these OpInfo tests (https://github.com/Lightning-AI/lightning-thunder/issues/810) +# TODO See issue "Add contiguous and clang.stride_order OpInfos that check stride +# consistency with PyTorch" @instantiate( dtypes=NOTHING, executors=(TorchExecutor,), @@ -2191,7 +2194,8 @@ def func(qkv): @instantiate(dtypes=NOTHING) def test_no_passthrough_symbol(executor, device, _): # A test case for the situation reported in - # https://github.com/Lightning-AI/lightning-thunder/issues/1131 + # "backward trace contains symbols not present in forward that cause + # NotImplementedError" # When an operation simply passes through its input, we should not # add it to the trace. diff --git a/thunder/tests/test_cudnn_executor.py b/thunder/tests/test_cudnn_executor.py index f90d2a231c..4128e02914 100644 --- a/thunder/tests/test_cudnn_executor.py +++ b/thunder/tests/test_cudnn_executor.py @@ -29,46 +29,47 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, req from thunder.tests.opinfos import SampleInput # TODO: cudnnex seems to produce large mismatches against reference when tensor initialized from the wider default range of [-9,9] - # https://github.com/Lightning-AI/lightning-thunder/issues/1871 + # See issue "cuDNN SDPA backward might return NaNs for inputs with absolute + # value more than certain threshold" make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=-0.5, high=0.5) n_head = 2 - N = 8 + N = 8 # batch size # TODO: multiple of 8 seems to produce NaNs - L = random.randint(1, 10) * 64 + L = random.randint(1, 10) * 64 # query's sequence length alignment_factor = 8 - S = random.randint(1, 10) * alignment_factor - E = random.randint(8, 16) * alignment_factor - Ev = random.randint(8, 16) * alignment_factor + S = random.randint(1, 10) * alignment_factor # key/value's sequence length + E = random.randint(8, 16) * alignment_factor # query/key's embedding size + Ev = random.randint(8, 16) * alignment_factor # value's embedding size # 4-dim (multiheaded) causal cases q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev) - yield SampleInput(q, k, v, attn_mask := None, dropout_p := 0.0, is_causal := True) + yield SampleInput(q, k, v, None, dropout_p=0.0, is_causal=True) # TODO: cudnnex seems to have a few mismatches. Will be enabled in a later PR. # Non-contiguous input tensor case nq = make(N, n_head, L, E).permute(0, 1, 3, 2) nk = make(N, n_head, L, E).permute(0, 1, 3, 2) nv = make(N, n_head, L, E).permute(0, 1, 3, 2) - yield SampleInput(nq, nk, nv, attn_mask := None, dropout_p := 0.0, is_causal := False) + yield SampleInput(nq, nk, nv, None, dropout_p=0.0, is_causal=False) # Test the scale factor which was added in torch 2.1 if LooseVersion(torch.__version__) >= LooseVersion("2.1.0"): q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev) - yield SampleInput(q, k, v, attn_mask := None, dropout_p := 0.0, is_causal := False, scale=0.123) + yield SampleInput(q, k, v, None, dropout_p=0.0, is_causal=False, scale=0.123) # TODO: cudnnex only support of grad_attn_mask with batch dim 1 and both sequence lenghts divisible by 64. Release 9.0.1 will relax this constraint. # Additive attn_mask q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev) additive_attn_mask = make((1, n_head, L, S), dtype=q.dtype).tril() - yield SampleInput(q, k, v, attn_mask := additive_attn_mask, is_causal=False) + yield SampleInput(q, k, v, additive_attn_mask, is_causal=False) # Boolean attn_mask q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev) bool_attn_mask = make((1, n_head, L, S), dtype=torch.bool, low=1, high=1, requires_grad=False).tril() - yield SampleInput(q, k, v, attn_mask := bool_attn_mask, is_causal=False) + yield SampleInput(q, k, v, bool_attn_mask, is_causal=False) grad_sdpa_cudnn_opinfo = OpInfo( @@ -87,10 +88,6 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, req ) -# WARNING: cudnn executor is experimental. Tests that use cudnn might fail.\n -# Issue for tracking support: https://github.com/Lightning-AI/lightning-thunder/issues/880 -# NOTE This test modifies the global executor map, so it technically should not -# be run in parallel with other tests @requiresCUDA def test_cudnn_sdpa(): # expect sdpa to fail for 8.9.2 and below @@ -113,10 +110,8 @@ def test_cudnn_sdpa(): query = 1 * (torch.randn(shape_Q, dtype=thunder.torch.to_torch_dtype(dtype), device="cuda") - 0.5) key = 2 * (torch.randn(shape_K, dtype=thunder.torch.to_torch_dtype(dtype), device="cuda") - 0.5) value = 3 * (torch.randn(shape_V, dtype=thunder.torch.to_torch_dtype(dtype), device="cuda") - 0.5) - is_causal = False - attn_mask = torch.randn( - s_q, s_kv, requires_grad=False, device="cuda", dtype=thunder.torch.to_torch_dtype(dtype) - ) + is_causal = True + attn_mask = None expected = torch.nn.functional.scaled_dot_product_attention( query, key, value, is_causal=is_causal, attn_mask=attn_mask @@ -127,7 +122,7 @@ def test(query, key, value, is_causal=False, attn_mask=None): query, key, value, is_causal=is_causal, attn_mask=attn_mask ) - ctest = thunder.compile(test, executors_list=[cudnn_ex]) + ctest = thunder.jit(test, executors=[cudnn_ex]) actual = ctest(query, key, value, is_causal=is_causal, attn_mask=attn_mask) torch.testing.assert_close(actual, expected, atol=2e-2, rtol=1e-2) last_trace = thunder.last_traces(ctest)[-1] @@ -157,8 +152,6 @@ def snippet_torch_consistency(op, torch_op, sample): assert_close(thunder_result, torch_result, equal_nan=True, atol=0.0625, rtol=5e-2) -# WARNING: cudnn executor is experimental. Tests that use cudnn might fail.\n -# Issue for tracking support: https://github.com/Lightning-AI/lightning-thunder/issues/880 # TODO Make it easier for executors to write tests like this, including writing them out-of-tree # TODO The executor passed below is just a "dummy" that actually gets ignored -- we should provide # a way to use decorators like @ops without a particular executor @@ -184,7 +177,7 @@ def test_cudnn_vs_torch_consistency(op, device, dtype, *_): pytest.xfail("Only interleaved layout is supported pre 8.9.2.") for sample in op.reference_inputs(device, dtype, requires_grad=False): - cfn = thunder.compile(op_name_to_fn[op.name], executors_list=[cudnn_ex, cudnn_layernorm_ex]) + cfn = thunder.jit(op_name_to_fn[op.name], executors=[cudnn_ex, cudnn_layernorm_ex]) result = run_snippet( snippet_torch_consistency, @@ -207,17 +200,19 @@ def test_cudnn_vs_torch_consistency(op, device, dtype, *_): supported_devicetypes=(devices.DeviceType.CUDA,), ) def test_vjp_correctness_sdpa_cudnnex_manual(op, device, dtype, executor, comp): - ran_atleast_one = False for sample in op.reference_inputs(device, dtype, requires_grad=True): - from thunder.executors.cudnnex import cudnn_ex - # Enforce tensor arguments are contiguous for torch reference contiguous_args = list(map(lambda a: a.contiguous() if isinstance(a, torch.Tensor) else a, sample.args)) # query, key, value grad_inputs = list(contiguous_args[:3]) - if (attn_mask := sample.args[3]) is not None and attn_mask.requires_grad: - grad_inputs.append(attn_mask) + if (attn_mask := sample.args[3]) is not None: + if attn_mask.requires_grad: + grad_inputs.append(attn_mask) + # TODO(#2470): With cudnn frontend 1.1 and A100, this test hits + # RuntimeError when `attn_mask` is provided: `[cudnn_frontend] + # Error: No execution plans built successfully`. + continue # Compute vjp result using PyTorch expect_out = op.torch_reference(*contiguous_args, **sample.kwargs) @@ -235,17 +230,10 @@ def test_vjp_correctness_sdpa_cudnnex_manual(op, device, dtype, executor, comp): executors_list=executor.executors_list() + [cudnn_ex], ) - try: - actual_out, actual_grad = cfoo(filtered_args, (v,)) - except Exception as e: - continue + actual_out, actual_grad = cfoo(filtered_args, (v,)) comp(actual_out, expect_out, atol=1e-2, rtol=1e-2) # compare gradients of query, key, value, and attn_mask for eg, ag in zip(expected_grad, actual_grad): comp(eg, ag, atol=2e-1, rtol=2e-2) - - ran_atleast_one = True - - assert ran_atleast_one == True diff --git a/thunder/tests/test_elementwise.py b/thunder/tests/test_elementwise.py index a76db9e673..af5dcfb58b 100644 --- a/thunder/tests/test_elementwise.py +++ b/thunder/tests/test_elementwise.py @@ -25,7 +25,7 @@ def test_elementwise_dunder_operations_on_numbers(executor, device, dtype): # (math.floor, (bool, int, float)), # (operator.inv, (bool, int)), # (operator.neg, (bool, int, float, complex)), - # # TODO https://github.com/Lightning-AI/lightning-thunder/issues/713 + # # TODO see issue "Implement positive operations" # # operator.pos, # (builtins.round, (bool, int, float)), # (math.trunc, (bool, int, float)), @@ -65,8 +65,8 @@ def foo(a): assert_close(actual, expected) -# TODO Test operator and method variants using OpInfos -# See https://github.com/Lightning-AI/lightning-thunder/issues/710 +# TODO: see issue "Test operator and method variants of operations using +# OpInfos" @instantiate(dtypes=(thunder.float32,)) def test_core_tensor_methods(executor, device, dtype): def foo(a, b, c, d): @@ -126,7 +126,8 @@ def test_where(executor, device, dtype): torch_result = torch_fn(pred, i64, -2.3) assert_close(thunder_result, torch_result) - # TODO Fix https://github.com/Lightning-AI/lightning-thunder/issues/711 + # TODO fix issue "Currently nvFuser tensor x float operations result in + # float64 results" # float x int # thunder_result = thunder_fn(pred, 3., 5) # torch_result = torch_fn(pred, 3., 5) diff --git a/thunder/tests/test_examine_memory.py b/thunder/tests/test_examine_memory.py index 1d01bef894..13bd05360b 100644 --- a/thunder/tests/test_examine_memory.py +++ b/thunder/tests/test_examine_memory.py @@ -63,12 +63,13 @@ def bar(a, b): # [4] [2,2] with runtime_allocated_memory(device): cbar(a, b) - traces = thunder.last_traces(cbar) - fwd_extrace = traces[0][-1] + fw_traces = thunder.last_traces(cbar) + fwd_extrace = fw_traces[-1] max_mem_fwd = get_alloc_memory(fwd_extrace) assert max_mem_fwd[0] == 144 assert sum(max_mem_fwd[1].values()) == get_return_memory(fwd_extrace.bound_symbols[-1]) # 144 - bw_extrace = traces[1][-1] + bw_traces = thunder.last_backward_traces(cbar) + bw_extrace = bw_traces[-1] max_mem_bw = get_alloc_memory(bw_extrace) assert max_mem_bw[0] == 144 assert sum(max_mem_bw[1].values()) == get_return_memory(bw_extrace.bound_symbols[-1]) # 32 @@ -137,9 +138,8 @@ def test_nanogpt_block(executor, device, dtype): result = cblock(inp) with runtime_allocated_memory(device): result.backward(torch.ones_like(result)) - traces = thunder.last_traces(cblock) - fw_extrace = traces[0][-1] - bw_extrace = traces[1][-1] + fw_extrace = thunder.last_traces(cblock)[-1] + bw_extrace = thunder.last_backward_traces(cblock)[-1] fw_alloc_mem = get_alloc_memory(fw_extrace) bw_alloc_mem = get_alloc_memory(bw_extrace) diff --git a/thunder/tests/test_extend.py b/thunder/tests/test_extend.py index 9e77f8f496..b7dd300d12 100644 --- a/thunder/tests/test_extend.py +++ b/thunder/tests/test_extend.py @@ -131,3 +131,57 @@ def test_get_all_executors_includes_all_native_executors(): if torch.cuda.is_available(): expected.update({"nvfuser"}) assert actual == expected + + +def test_register_implementation_custom_op(): + myex = OperatorExecutor("myex", version="0.1") + register_executor(myex) + + def official_add(a, b): + return a + b + + def _myadd(a, b): + return a + b + + myadd1 = myex.register_operator("myadd1", like=_myadd, fn=_myadd, replaces=official_add) + myadd2 = myex.register_operator("myadd2", like=_myadd, fn=_myadd) + + def fn(a, b): + return official_add(a, b) + + cfn = thunder.jit(fn, executors=[myex]) + + a = torch.randn(2, 2) + b = torch.randn(2, 2) + + res = cfn(a, b) + + assert "myadd1" in str(thunder.last_traces(cfn)[-1]) + + def myadd_trafo(a, b): + return myadd2(a, b) + + def myadd_grad_trafo(a, b): + res = myadd2(a, b) + grad_res = get_grad(res) + put_grads((a, b), (grad_res, grad_res)) + return res + + myex.register_implementation(myadd1, execution_transform=myadd_trafo, grad_transform=myadd_grad_trafo) + + cfn = thunder.jit(fn, executors=[myex]) + res = cfn(a, b) + + s = str(thunder.last_traces(cfn)[-1]) + assert "myadd2" in s and "myadd1" not in s + + a.requires_grad_() + + res = cfn(a, b) + + s = str(thunder.last_traces(cfn)[-1]) + assert "myadd2" in s and "myadd1" not in s + + a.requires_grad_() + + deregister_executor(myex) diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 931c365fb9..3a44a25dbc 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -26,7 +26,7 @@ # TODO: Move this to thunder.tests.opinfos op_skip = { - # See https://github.com/Lightning-AI/lightning-thunder/issues/226 + # See issue "Support closures of torch.Tensor" # TODO: AttributeError: 'Tensor' object has no attribute 'true_dtype' "masked_fill", # TODO: RuntimeError: Expected index=tensor([2, 3, 2, 0, 3, 1, 0, 2], @@ -514,7 +514,7 @@ def test_vjp_correctness_sdpa_manual(op, device, dtype, executor, comp): vjp(filtered_op), disable_torch_autograd_support=True, disable_preprocessing=True, - executors_list=executor.executors_list() + [sdpa_ex], + executors_list=[sdpa_ex, *executor.executors_list()], )(filtered_args, (v,)) comp(actual_out, expect_out) @@ -635,7 +635,8 @@ def bar(a, b): dtypes=NOTHING, ) def test_convert_element_type_with_float(executor, device, _): - # Verifies a fix for https://github.com/Lightning-AI/lightning-thunder/issues/537 + # Verifies the fix for "grad transform hits error: AttributeError: 'float' + # object has no attribute 'dtype'" from thunder.core.transforms import value_and_grad a = make_tensor([5], dtype=torch.float32, device=device) @@ -708,7 +709,9 @@ def sincos_backward(sin_x, cos_x, g1, g2): assert trace.output[0] == trace.bound_symbols[4].output -# TODO: Fix flaky test https://github.com/Lightning-AI/lightning-thunder/issues/1919 +# TODO: see issue +# "thunder/tests/test_grad.py::test_torch_autograd_saved_tensors_memory_release +# is flaky" @pytest.mark.xfail(strict=False, reason="This test is flaky") @requiresCUDA def test_torch_autograd_saved_tensors_memory_release(): @@ -822,6 +825,42 @@ def fun_bw(a, b, g): torch.testing.assert_close(actual_bw, expected_bw) +@instantiate( + dtypes=NOTHING, +) +def test_make_aug_forward_and_backward_var_mean(executor, device, _): + # This test checks that the split of the joint forward/backward function for + # var_mean correctly puts the forward part into the augmented forward + # function and the backward part into the backward function without + # overlapping symbols. + from thunder.core.vjp_utils import make_aug_forward_and_backward + from thunder.core.prims import var_mean + + def fun(a): + return var_mean(a, (0,), correction=1) + + x = torch.tensor((2, 2), device=device, dtype=torch.float32) + + trace = thunder.trace()(fun, x) + var_mean_bsym = trace.bound_symbols[-2] + assert var_mean_bsym.sym.name == "var_mean" + + aug_fw, bw = make_aug_forward_and_backward(var_mean_bsym) + aug_fw = executor.make_callable(aug_fw) + out, saved = aug_fw(x, (0,), correction=1) + bw = executor.make_callable(bw) + _ = bw(*saved, *out) + bw_trace = thunder.last_traces(bw)[0] + assert "var_mean" not in (s.sym.name for s in bw_trace.bound_symbols) + + +def test_no_duplicate_backward_registered(): + from thunder.core.transforms import backward_impls, _grad_fn_map + + same_keys = set(_grad_fn_map.keys()).intersection(set(backward_impls.keys())) + assert not same_keys, f"Duplicate keys: {same_keys}" + + @instantiate( dtypes=NOTHING, ) @@ -870,8 +909,7 @@ def func(a): def test_torch_autograd_crazy_collections_in_and_out(executor, device, dtype): from thunder.executors.torch_autograd import thunder_backward - # Borrowed from - # https://github.com/Lightning-AI/lightning-thunder/blob/3401475ee47d5a732b6b4d5dcbd88afcd9bed81d/thunder/tests/test_core.py#L117 + # Borrowed from `test_crazy_collections_in_and_out`. def foo(a, b, c, *, ka, kb, kc): d = { 5: 2, @@ -971,19 +1009,16 @@ def test_torch_autograd_module_get_compile_stats(executor, device, _): out.backward(g) compile_stats = compile_stats(lc) - primal_trace = compile_stats.primal_trace - forward_traces = compile_stats.forward_last_traces - backward_traces = compile_stats.backward_last_traces + forward_traces = compile_stats.last_traces + backward_traces = compile_stats.last_backward_traces assert isinstance(forward_traces, list) assert len(forward_traces) >= 1 assert isinstance(backward_traces, list) assert len(backward_traces) >= 1 - assert isinstance(primal_trace, TraceCtx) - fw_bw_traces = thunder.last_traces(lc) - assert isinstance(fw_bw_traces, tuple) - assert len(fw_bw_traces) == 2 - assert fw_bw_traces[0] == forward_traces - assert fw_bw_traces[1] == backward_traces + fw_traces = thunder.last_traces(lc) + bw_traces = thunder.last_backward_traces(lc) + assert fw_traces == forward_traces + assert bw_traces == backward_traces @instantiate( @@ -1250,11 +1285,11 @@ def test_populate_grads_mlp(executor, device, dtype): clear_grads(model) - tom = executor.make_callable_legacy(model, disable_preprocessing=False) + tom = executor.make_callable(model) tom_grad = grad(tom) thunder_grads = tom_grad(x) - populate_grads(thunder_grads, tom) + populate_grads(thunder_grads, tom, args=(x,)) thunder_grads = extract_grads(tom) assert_close(torch_grads, thunder_grads, atol=1e-3, rtol=1e-5) @@ -1277,11 +1312,11 @@ def test_populate_grads_csa(executor, device, dtype): clear_grads(model) - tom = executor.make_callable_legacy(model, disable_preprocessing=False) + tom = executor.make_callable(model) tom_grad = grad(tom) thunder_grads = tom_grad(x) - populate_grads(thunder_grads, tom) + populate_grads(thunder_grads, tom, args=[x]) thunder_grads = extract_grads(tom) assert_close(torch_grads, thunder_grads, atol=1e-2, rtol=1e-2) @@ -1304,11 +1339,11 @@ def test_populate_grads_block(executor, device, dtype): clear_grads(model) - tom = executor.make_callable_legacy(model, disable_preprocessing=False) + tom = executor.make_callable(model) tom_grad = grad(tom) thunder_grads = tom_grad(x) - populate_grads(thunder_grads, tom) + populate_grads(thunder_grads, tom, args=[x]) thunder_grads = extract_grads(tom) assert_close(torch_grads, thunder_grads, atol=1e-2, rtol=1e-2) @@ -1340,7 +1375,7 @@ def test_populate_grads_nanogpt(executor, device, dtype): clear_grads(model) - tom = executor.make_callable_legacy(model, disable_preprocessing=False) + tom = executor.make_callable(model) def grad_specifier(out) -> None: logits, loss = out @@ -1349,7 +1384,7 @@ def grad_specifier(out) -> None: tom_grad = grad(tom, grad_specifier=grad_specifier) thunder_grads = tom_grad(x, targets) - populate_grads(thunder_grads, tom) + populate_grads(thunder_grads, tom, args=[x, targets]) thunder_grads = extract_grads(tom) assert_close(torch_grads, thunder_grads, atol=1e-2, rtol=1e-2) diff --git a/thunder/tests/test_interpreter.py b/thunder/tests/test_interpreter.py index e113438f93..447647b8bb 100644 --- a/thunder/tests/test_interpreter.py +++ b/thunder/tests/test_interpreter.py @@ -1010,7 +1010,8 @@ def foo(x): # } -# See https://github.com/Lightning-AI/lightning-thunder/issues/2078 +# Test for issue "jit: passing jitted functions as arguments to jitted +# functions fails." def test_reduce_jitted_reduce_fn(jit): import functools @@ -1482,7 +1483,7 @@ def foo(): assert jfoo() is True -@pytest.mark.xfail(reason="https://github.com/Lightning-AI/lightning-thunder/issues/1824") +@pytest.mark.xfail(reason='"exec() and eval() lookaside ignores locals()"') def test_exec_import_star(jit): # Assert that we can actually generate the instruction to_exec = "from itertools import *" @@ -2606,8 +2607,8 @@ def test_displayhook(jit): import io import code - # TODO: Implement the lookaside for exec(). Under he hood, `code.InteractiveInterpreter().runsource('5;6;7')`` - # just compiles the string and calls exec(), plus a little bit of irrelevant error handling. + # TODO: Implement the lookaside for exec(). Under the hood, `code.InteractiveInterpreter().runsource('5;6;7')`` + # just compiles the string and calls exec(), plus a little bit of error handling. # I'm not entirely convinced that the PRINT_EVAL is going through our system at the moment, but # it for sure would with an exec() lookaside. I'm also not sure what makes InteractiveInterpreter # interactive. It isn't *actually* in interactive mode. So, why is PRINT_EXPR in the interpreted @@ -2616,7 +2617,7 @@ def test_displayhook(jit): py_redirect = io.StringIO() with redirect_stdout(py_redirect): # Avoid clobbering this interpreter's display hook, and ensure it's interactive. - # Why is this necessary? I'm not sure. + # Why is this necessary? interpreter = code.InteractiveInterpreter() def smt(s): diff --git a/thunder/tests/test_jit_functional.py b/thunder/tests/test_jit_functional.py index e65ed69aec..7fece51f72 100644 --- a/thunder/tests/test_jit_functional.py +++ b/thunder/tests/test_jit_functional.py @@ -292,7 +292,7 @@ def test_binary_ops_compare_numbers(): def test_binary_ops_int_numbers(): - # Issue https://github.com/Lightning-AI/lightning-thunder/issues/594 for more ops + # TODO: see issue "Implement logical and arithmetic left and right shifts" # "<<", ">>", int_ops = ["+", "&", "//", "*", "%", "|", "**", "-", "/", "^"] @@ -1574,7 +1574,7 @@ def foo(a, b): assert_close(expected, actual) -@pytest.mark.xfail(reason="https://github.com/Lightning-AI/lightning-thunder/issues/2191") +@pytest.mark.xfail(reason='issue: "jit-eager: allow sets as a return value"') def test_return_set(): def foo(a, b): return {a, b} @@ -2453,7 +2453,7 @@ def foo(): jfoo() -@pytest.mark.xfail(reason="https://github.com/Lightning-AI/lightning-thunder/issues/2184") +@pytest.mark.xfail(reason='issue: "sharp edges: loading closures"') def test_input_closure_sharp_edge(): x = 5 @@ -2486,7 +2486,7 @@ def _test_fn_global_no_sharp_edge_fn(): return 7 -@pytest.mark.xfail(reason="https://github.com/Lightning-AI/lightning-thunder/issues/2189") +@pytest.mark.xfail(reason='issue: "sharp edge: allow function and module loads"') def test_fn_global_no_sharp_edge(): def foo(x): return x + _test_fn_global_no_sharp_edge_fn() diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py index 2c591d3680..28fc8a54fc 100644 --- a/thunder/tests/test_jit_general.py +++ b/thunder/tests/test_jit_general.py @@ -23,6 +23,7 @@ import thunder.core.prims as prims from thunder import pytorch_executor, nvfuser_executor from thunder.executors.sdpaex import sdpa_ex +from thunder.core.jit_ext import JITSharpEdgeError # @@ -49,6 +50,25 @@ def skipif_not_pytorch_2_1(f): )(f) +def test_jitting_through_opaque_torch_symbols_sharp_edge(): + def no_sharp_edge(x): + # randn_like is in ltorch + return torch.randn_like(x) + + def sharp_edge(x): + # rand_like is not yet in ltroch + return torch.rand_like(x) + + x = torch.rand(1) + + jno_sharp_edge = thunder.jit(no_sharp_edge, sharp_edges="error") + jno_sharp_edge(x) + + jsharp_edge = thunder.jit(sharp_edge, sharp_edges="error") + with pytest.raises(JITSharpEdgeError): + jsharp_edge(x) + + def test_binary_add_tensors(): def foo(a, b): return a + b @@ -330,7 +350,7 @@ def foo(a, b): jfoo = thunder.jit(foo) # TODO Add test for bool - # See https://github.com/Lightning-AI/lightning-thunder/issues/1990 + # see issue "Binary addition on booleans should promote to an integer" cases = ( (2, 3), (2.1, 3.4), @@ -377,7 +397,7 @@ def foo(a, b): jfoo = thunder.jit(foo) # TODO Add test for bool - # See https://github.com/Lightning-AI/lightning-thunder/issues/1990 + # see issue "Binary addition on booleans should promote to an integer" cases = ( (2, 3), (2.1, 3.4), @@ -394,7 +414,7 @@ def foo(a, b): _test_add_global_global = 2 -@pytest.mark.xfail(reason="https://github.com/Lightning-AI/lightning-thunder/issues/1935", raises=BaseException) +@pytest.mark.xfail(reason='"disallow global reads and writes (temporarily)"', raises=BaseException) def test_global_fails(): def foo(): return _test_add_global_global @@ -405,7 +425,10 @@ def foo(): jfoo() -@pytest.mark.xfail(reason="https://github.com/Lightning-AI/lightning-thunder/issues/1936", raises=BaseException) +@pytest.mark.xfail( + reason='"Raise an error when a program attempts to write to a nonlocal that was captured from outside the interpreter"', + raises=BaseException, +) def test_nonlocal_outside_interpreter_fails(): def foo(): x = 3 diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py index 971201bcba..15760f2948 100644 --- a/thunder/tests/test_networks.py +++ b/thunder/tests/test_networks.py @@ -41,7 +41,7 @@ def test_nanogpt_complete(executor, device, dtype): # TODO Investigate grad inconsistency # TODO: Add float16 and bfloat16 comparison tests here and to all other tests in # this file. -# https://github.com/Lightning-AI/lightning-thunder/issues/907 +# See issue "Add half precision dtype tests to test_networks.py" @instantiate(dtypes=(thunder.float32,)) def test_nanogpt_complete_autograd(executor, device, dtype): tdtype = ttorch.to_torch_dtype(dtype) diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py index 3cbe631bb9..cdeac92a29 100644 --- a/thunder/tests/test_nvfuser.py +++ b/thunder/tests/test_nvfuser.py @@ -315,7 +315,7 @@ def func(w, x, y, z): @instantiate(dtypes=NOTHING, devicetypes=(devices.DeviceType.CUDA,), executors=(nvFuserExecutor,)) def test_cse_rematerialization(executor, device, _): - # Unit test for https://github.com/Lightning-AI/lightning-thunder/issues/2046 + # Unit test for "llama2.c example failed with bookend disabled." from thunder.tests.llama2_model import Transformer, ModelArgs from thunder.core.pytree import tree_flatten @@ -338,10 +338,10 @@ def test_cse_rematerialization(executor, device, _): x = torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int64, device=device) y = torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int64, device=device) - compiled_func = thunder.compile( + compiled_func = thunder.jit( model.eval(), - disable_torch_autograd_support=True, - executors_list=executor.executors_list(), + disable_torch_autograd=True, + executors=executor.executors_list(), nv_enable_bookend=False, ) compiled_func(x, y) @@ -356,11 +356,11 @@ def test_cse_rematerialization(executor, device, _): assert len(fusion_bsyms) == 11 # fusion groups 1 and 6 correspond with the apply_rotary_emb function # Nvfuser with recomputation should use precomputed cos and sin values. - assert len(fusion_bsyms[1].args) == len(fusion_bsyms[6].args) - assert fusion_bsyms[1].args[0].name == "freqs_cos" - assert fusion_bsyms[1].args[1].name == "freqs_sin" - assert fusion_bsyms[6].args[0].name == "freqs_cos" - assert fusion_bsyms[6].args[1].name == "freqs_sin" + assert len(fusion_bsyms[1].args) == len(fusion_bsyms[7].args) + assert fusion_bsyms[1].subsymbols[0].output.name == "freqs_cos" + assert fusion_bsyms[1].subsymbols[1].output.name == "freqs_sin" + assert fusion_bsyms[7].subsymbols[0].output.name == "freqs_cos" + assert fusion_bsyms[7].subsymbols[1].output.name == "freqs_sin" # Tests that two separated nvFuser regions can be merged when they don't depend @@ -614,8 +614,7 @@ def func(x: torch.Tensor, s: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: dtypes=NOTHING, executors=( nvFuserExecutor, - # NOTE torch executor does not have bookend optimization. - # See comment: https://github.com/Lightning-AI/lightning-thunder/issues/571#issuecomment-1610778432 + # NOTE We might want to do transpose bookend optimization for other executors than nvFuser. ), ) def test_bookend_meta_optimization(executor, device, _): diff --git a/thunder/tests/test_nvfuser_remat.py b/thunder/tests/test_nvfuser_remat.py index 0f8560278b..b94403fcf3 100644 --- a/thunder/tests/test_nvfuser_remat.py +++ b/thunder/tests/test_nvfuser_remat.py @@ -310,8 +310,8 @@ def test_find_cut_dropout(executor, device, _): ext_producer_outputs = find_external_producer_outputs(utils.consumers(trace), (), producer, consumer) cut = find_cut(ext_producer_outputs, producer, consumer) # Note t5 is the boolean mask for dropout. It should be chosen over the t6 - # that is the float32 mask. See this issue for the original problem: - # https://github.com/Lightning-AI/lightning-thunder/issues/706 + # that is the float32 mask. See this issue: "The Recomputation Algorithm on + # Dropout choses a float32 mask to save" assert cut == ("t0", "t5", "t9") diff --git a/thunder/tests/test_ops.py b/thunder/tests/test_ops.py index b9d2fed286..867ca35a76 100644 --- a/thunder/tests/test_ops.py +++ b/thunder/tests/test_ops.py @@ -37,7 +37,7 @@ def snippet_torch_consistency(op: OpInfo, torch_op, sample: SampleInput, comp: C thunder_result = op(*sample.args, **sample.kwargs) torch_result = torch_op(*sample.args, **sample.kwargs) - # TODO Review how lightning.compile returns Exception information + # TODO Review how thunder.jit returns Exception information if isinstance(thunder_result, Exception): raise thunder_result diff --git a/thunder/tests/test_transformer_engine_executor.py b/thunder/tests/test_transformer_engine_executor.py index 60aa6f64ea..41b6f83e36 100644 --- a/thunder/tests/test_transformer_engine_executor.py +++ b/thunder/tests/test_transformer_engine_executor.py @@ -70,7 +70,8 @@ def fn(x, w1, w2): assert_close(w2.grad, te_linear2.weight.grad) # Verifies te_linear was called - forward_trace, backward_trace = thunder.last_traces(cfn) + forward_trace = thunder.last_traces(cfn) + backward_trace = thunder.last_backward_traces(cfn) assert any(bsym.sym.name.startswith("te_linear") for bsym in forward_trace[-1].bound_symbols) assert any(bsym.sym.name.startswith("te_functional_linear_backward") for bsym in backward_trace[-1].bound_symbols) @@ -180,6 +181,6 @@ def foo(x, w): ) cfunc(x, w) - fwd_traces, _ = thunder.last_traces(cfunc) + fwd_traces = thunder.last_traces(cfunc) # Verify that we have replaced `prims.linear` with `te_linear` assert any(bsym.sym.name.startswith("te_linear") for bsym in fwd_traces[-1].bound_symbols) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index a4388d9903..efb6319d83 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -12,7 +12,7 @@ import opt_einsum -# Initialies the language context +# Initializes the language context from thunder.torch.langctx import register_method import thunder.clang as clang @@ -29,6 +29,7 @@ from thunder.core.symbol import Symbol from thunder.core.transforms import register_grad, put_grads from thunder.core.prims import get_grad, put_grad +from thunder.core.baseutils import run_once __all__ = [ "is_available", @@ -224,8 +225,9 @@ def _parse_to_device_and_dtype( dtype = to_dtype(dtype) # Case 1 -- tensor first else: - # See https://github.com/Lightning-AI/lightning-thunder/issues/317 - # It'd be nice to write torch.Tensor here instead of TensorProxy + # It'd be nice to write torch.Tensor here instead of TensorProxy. + # See issue "Translate isinstance(a, torch.Tensor) calls so that + # TensorProxies can pass as torch.Tensors" utils.check_type(tensor_dtype_or_device, TensorProxy) device_ = tensor_dtype_or_device.device if device is None else to_device(device) dtype_ = tensor_dtype_or_device.true_dtype if dtype is None else to_dtype(dtype) @@ -413,7 +415,8 @@ def multinomial( ) -> TensorLike: utils.check(out is None, lambda: "Non-None out is not supported", NotImplementedError) - # See https://github.com/Lightning-AI/lightning-thunder/issues/2258 + # See issue "randomness: enable PyTorch generators for operations like + # multinomial" utils.check( generator is None, lambda: f"multinomial does not yet support specifying a generator", NotImplementedError ) @@ -430,7 +433,7 @@ def multinomial( # TODO Maybe update this to return an offset of how far to advance the seed to acquire new values -# See https://github.com/Lightning-AI/lightning-thunder/issues/1360 +# See issue "Maybe return offset from thunder.torch.uniform_philox" @torchsymbol(is_method=False, id="torch.uniform_philox") def uniform_philox( shape: Sequence[int], @@ -884,6 +887,33 @@ def squeeze(a: TensorLike, /, dim: None | int | Sequence[int] = None) -> TensorL return clang.squeeze(a, dims) +@torchsymbol(torch.t, is_method=True) +def t(a: TensorLike, /) -> TensorLike: + utils.check( + a.ndim <= 2, + lambda: f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D", + RuntimeError, + ) + return prims.transpose(a, (1, 0)) if a.ndim == 2 else a + + +@run_once +def warn_ndim_not_2(): + warnings.warn( + "The use of `x.T` on tensors of dimension other than 2 to reverse their shape is deprecated and will throw an error in a future release." + "Consider `x.mT` to transpose batches of matrices or `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor." + ) + + +def reverse_dims_T(a: TensorLike, /) -> TensorLike: + if a.ndim != 2: + warn_ndim_not_2() + return a if a.ndim < 2 else prims.transpose(a, tuple(reversed(range(a.ndim)))) + + +register_method("T", reverse_dims_T) + + # TODO Add type annotations # See https://pytorch.org/docs/master/generated/torch.tensor_split.html @torchsymbol(torch.tensor_split, is_method=True) @@ -1362,9 +1392,9 @@ def zeta(a, b, /): # For calculate op1(a, op2(value, op2(b, c))) by promoting all input tensors at once # NOTE use this explicit type promotion because a direct combination of add/mul will have a redundant cast, -# which may lead to accuracy problems, see: -# https://github.com/Lightning-AI/lightning-thunder/pull/1155#discussion_r1342653591 for details -# TODO remove this when the optimization pass is ready: https://github.com/Lightning-AI/lightning-thunder/issues/1178 +# which may lead to accuracy problems. +# TODO remove after issue "Redundant cast removal could be performed through metadata-only +# operations, like broadcasting" is resolved def addcmul_addcdiv_helper( a, b, c, op1, op2, *, value=None, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT ): @@ -2840,8 +2870,9 @@ def _avg_pool_helper( # Dimensionality of the kernel. kernel_numel = reduce(operator.mul, kernel_size, 1) - # nn.functional.avg_pool does not have `divisor_override` for some reason. - # TODO: seems like an oversight from PyTorch and/or 1d case is very niche. + # nn.functional.avg_pool does not have `divisor_override`. + # TODO: look into PyTorch side; is this behavior deliberate? Could be that + # 1D case is niche. # If needed, handle it with checks and transforms. For now unconditionally # override value with kernel_numel. if divisor_override is None or dim == 1: @@ -3003,7 +3034,8 @@ def _dropout_helper(a, p): # TODO Add annotations, make not a prim # The backward decomposition of cross_entropy cannot be efficiently fused, so we have this cross_entropy_backward # primitive. Executors can override the primitive using internal implementations. -# See https://github.com/Lightning-AI/lightning-thunder/issues/660 +# See issue "Cross_entropy is decomposed for backward but the decomposition is +# not fusible currently" @torchsymbol("cross_entropy_backward", id="cross_entropy_backward", is_prim=True) def cross_entropy_backward(g, a, /, target, weight, reduction, ignore_index, label_smoothing): return TensorProxy(like=g, shape=a.shape) @@ -3571,7 +3603,8 @@ def log_softmax(a: TensorLike, /, dim: int, *, dtype: None | dtypeLike = None) - # TODO Update annotations and consider moving to torchex # We improve the efficiency of cross_entropy backward decomposition by adding the log_softmax_backward # and nll_loss_backward primitives. Executors can override the primitives using internal implementations. -# See https://github.com/Lightning-AI/lightning-thunder/issues/660 +# See issue "Cross_entropy is decomposed for backward but the decomposition is +# not fusible currently" @torchsymbol("log_softmax_backward", id="log_softmax_backward") def log_softmax_backward(g: TensorProxy, /, output: TensorProxy, dim: int, dtype: dtypeLike) -> TensorLike: dtype: dtypes.dtype = to_dtype(dtype) @@ -3817,7 +3850,7 @@ def softmax(a: TensorLike, /, dim: int, *, dtype: None | dtypeLike = None) -> Te if torch.distributed.is_available(): DistributedReduceOpLike = str | torch.distributed.ReduceOp | dist_prims.DistributedReduceOps - # string name, PyTorch enum value, lightning.compile enum value + # string name, PyTorch enum value, thunder.jit enum value _reduceop_triples = (("sum", torch.distributed.ReduceOp.SUM, dist_prims.DistributedReduceOps.SUM),) def to_thunder_distributed_reduce_op(op: DistributedReduceOpLike | None):