diff --git a/.github/workflows/ci-llama-large-tests.yaml b/.github/workflows/ci-llama-large-tests.yaml index 5766d8ca6..376f93938 100644 --- a/.github/workflows/ci-llama-large-tests.yaml +++ b/.github/workflows/ci-llama-large-tests.yaml @@ -33,7 +33,6 @@ jobs: run: shell: bash env: - PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" VENV_DIR: ${{ github.workspace }}/.venv steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -47,16 +46,12 @@ jobs: uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{matrix.version}} - - - name: Cache Pip Packages - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 - id: cache-pip - with: - path: ${{ env.PIP_CACHE_DIR }} - key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }} + - name: Create Python venv + run: python -m venv ${VENV_DIR} - name: Install pip deps run: | + source ${VENV_DIR}/bin/activate python -m pip install --no-compile --upgrade pip # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU @@ -68,14 +63,17 @@ jobs: pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" - # Test with nightly releases, not what iree-turbine uses. pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ iree-base-compiler \ iree-base-runtime + pip freeze + - name: Run llama tests - run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --run-nightly-llama-tests --iree-hip-target=gfx942 --iree-device=hip://7 --html=out/llm/llama/benchmark/index.html + run: | + source ${VENV_DIR}/bin/activate + pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --run-nightly-llama-tests --iree-hip-target=gfx942 --iree-device=hip://7 --html=out/llm/llama/benchmark/index.html - name: Deploy to GitHub Pages uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 diff --git a/.github/workflows/ci-llama-quick-tests.yaml b/.github/workflows/ci-llama-quick-tests.yaml index 697c47928..ddbcc204a 100644 --- a/.github/workflows/ci-llama-quick-tests.yaml +++ b/.github/workflows/ci-llama-quick-tests.yaml @@ -33,7 +33,6 @@ jobs: run: shell: bash env: - PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" VENV_DIR: ${{ github.workspace }}/.venv steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -47,16 +46,12 @@ jobs: uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{matrix.version}} - - - name: Cache Pip Packages - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 - id: cache-pip - with: - path: ${{ env.PIP_CACHE_DIR }} - key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }} + - name: Create Python venv + run: python -m venv ${VENV_DIR} - name: Install pip deps run: | + source ${VENV_DIR}/bin/activate python -m pip install --no-compile --upgrade pip # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU @@ -68,14 +63,17 @@ jobs: pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" - # Test with nightly releases, not what iree-turbine uses. 
pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ iree-base-compiler \ iree-base-runtime + pip freeze + - name: Run llama 8b f16 decomposed test - run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --iree-hip-target=gfx942 --iree-device=hip://0 --run-quick-llama-test + run: | + source ${VENV_DIR}/bin/activate + pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --iree-hip-target=gfx942 --iree-device=hip://0 --run-quick-llama-test - name: Upload llama executable files uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 diff --git a/.github/workflows/ci-sdxl.yaml b/.github/workflows/ci-sdxl.yaml index 708cb3885..102ef7817 100644 --- a/.github/workflows/ci-sdxl.yaml +++ b/.github/workflows/ci-sdxl.yaml @@ -35,9 +35,9 @@ env: LIBSHORTFIN_DIR: ${{ github.workspace }}/shortfin/ jobs: - build-and-test: - name: Build and test - runs-on: mi300-sdxl-kernel + install-and-test: + name: Install and test + runs-on: mi300x-4 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -55,53 +55,27 @@ jobs: sudo apt install ninja -y fi - - name: Checkout IREE repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - repository: iree-org/iree - path: ${{ env.IREE_REPO_DIR }} - submodules: false - ref: iree-3.1.0rc20241204 - - - name: Initalize IREE submodules - working-directory: ${{ env.IREE_REPO_DIR }} - run : | - git submodule update --init --depth 1 -- third_party/benchmark - git submodule update --init --depth 1 -- third_party/cpuinfo/ - git submodule update --init --depth 1 -- third_party/flatcc - git submodule update --init --depth 1 -- third_party/googletest - git submodule update --init --depth 1 -- third_party/hip-build-deps/ - - name: Setup Python uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: "3.12" cache: "pip" - - name: Install Python packages - # TODO: Switch to `pip install -r requirements.txt -e shortfin/`. + + - name: Install requirements working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | pip install -r requirements-tests.txt pip install -r requirements-iree-compiler.txt pip freeze - - name: Build shortfin (full) + - name: Install shortfin working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | - mkdir build - cmake -GNinja \ - -S. \ - -Bbuild \ - -DCMAKE_C_COMPILER=clang-18 \ - -DCMAKE_CXX_COMPILER=clang++-18 \ - -DSHORTFIN_BUNDLE_DEPS=ON \ - -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \ - -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON - cmake --build build --target all - pip install -v -e build/ + pip install --no-compile -e . 
- - name: Test shortfin (full) + - name: Test apps/sd/e2e_test working-directory: ${{ env.LIBSHORTFIN_DIR }} + env: + HIP_VISIBLE_DEVICES: 0 run: | - ctest --timeout 30 --output-on-failure --test-dir build - HIP_VISIBLE_DEVICES=0 pytest tests/apps/sd/e2e_test.py -v -s --system=amdgpu + pytest tests/apps/sd/e2e_test.py -v -s --system=amdgpu diff --git a/.github/workflows/ci-sglang-benchmark.yml b/.github/workflows/ci-sglang-benchmark.yml index e6af1e73c..d53189483 100644 --- a/.github/workflows/ci-sglang-benchmark.yml +++ b/.github/workflows/ci-sglang-benchmark.yml @@ -45,7 +45,7 @@ jobs: run: shell: bash env: - PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + VENV_DIR: ${{ github.workspace }}/.venv steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -54,16 +54,12 @@ jobs: uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{matrix.version}} - - - name: Cache Pip Packages - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 - id: cache-pip - with: - path: ${{ env.PIP_CACHE_DIR }} - key: pip-${{ matrix.version }}-${{ hashFiles('*requirements*.txt','shortfin/requirements*.txt','sharktank/requirements*.txt') }} + - name: Create Python venv + run: python -m venv ${VENV_DIR} - name: Install pip deps run: | + source ${VENV_DIR}/bin/activate python -m pip install --no-compile --upgrade pip # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU @@ -81,11 +77,15 @@ jobs: iree-base-runtime==3.1.0rc20241204 \ "numpy<2.0" - - name: Install SGLang - run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python" + # Install SGLang + pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python" + + pip freeze - name: Run Shortfin Benchmark Tests - run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmarks/shortfin_benchmark_test.py --log-cli-level=INFO --html=shortfin_index.html --self-contained-html + run: | + source ${VENV_DIR}/bin/activate + pytest -v app_tests/benchmark_tests/llm/sglang_benchmarks/shortfin_benchmark_test.py --log-cli-level=INFO --html=shortfin_index.html --self-contained-html - name: Upload pytest report uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 @@ -103,8 +103,6 @@ jobs: defaults: run: shell: bash - env: - PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -114,13 +112,6 @@ jobs: with: python-version: ${{matrix.version}} - - name: Cache Pip Packages - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 - id: cache-pip - with: - path: ${{ env.PIP_CACHE_DIR }} - key: pip-${{ matrix.version }} - - name: Install SGLang run: | python -m pip install --no-compile --upgrade pip diff --git a/.github/workflows/ci-sglang-integration-tests.yml b/.github/workflows/ci-sglang-integration-tests.yml index 154657504..36a59779a 100644 --- a/.github/workflows/ci-sglang-integration-tests.yml +++ b/.github/workflows/ci-sglang-integration-tests.yml @@ -34,7 +34,7 @@ jobs: run: shell: bash env: - PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + VENV_DIR: ${{ github.workspace }}/.venv steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -43,16 +43,12 @@ jobs: uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{matrix.version}} - - - name: Cache Pip Packages - uses: 
actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 - id: cache-pip - with: - path: ${{ env.PIP_CACHE_DIR }} - key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','shortfin/requirements*.txt','sharktank/requirements*.txt') }} + - name: Create Python venv + run: python -m venv ${VENV_DIR} - name: Install pip deps run: | + source ${VENV_DIR}/bin/activate python -m pip install --no-compile --upgrade pip # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU @@ -69,11 +65,13 @@ jobs: iree-base-runtime \ "numpy<2.0" - - name: Install SGLang - run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python" + # Install SGLang and sentence_transformers + pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python" + pip install sentence_transformers - - name: Install sentence_transformers - run: pip install sentence_transformers + pip freeze - name: Run Integration Tests - run: pytest -v app_tests/integration_tests/llm/sglang --log-cli-level=INFO + run: | + source ${VENV_DIR}/bin/activate + pytest -v app_tests/integration_tests/llm/sglang --log-cli-level=INFO diff --git a/.github/workflows/ci-shark-ai.yml b/.github/workflows/ci-shark-ai.yml index 7ec69e13b..e662d125b 100644 --- a/.github/workflows/ci-shark-ai.yml +++ b/.github/workflows/ci-shark-ai.yml @@ -33,7 +33,7 @@ jobs: run: shell: bash env: - PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + VENV_DIR: ${{ github.workspace }}/.venv steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -42,16 +42,12 @@ jobs: uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{matrix.version}} - - - name: Cache Pip Packages - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 - id: cache-pip - with: - path: ${{ env.PIP_CACHE_DIR }} - key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','shortfin/requirements*.txt','sharktank/requirements*.txt') }} + - name: Create Python venv + run: python -m venv ${VENV_DIR} - name: Install pip deps run: | + source ${VENV_DIR}/bin/activate python -m pip install --no-compile --upgrade pip # Note: We install in three steps in order to satisfy requirements # from non default locations first. 
Installing the PyTorch CPU @@ -70,5 +66,9 @@ jobs: iree-base-compiler \ iree-base-runtime + pip freeze + - name: Run LLM Integration Tests - run: pytest -v app_tests/integration_tests/llm/shortfin --log-cli-level=INFO + run: | + source ${VENV_DIR}/bin/activate + pytest -v app_tests/integration_tests/llm/shortfin --log-cli-level=INFO diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index 7433bf167..565bf3352 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -35,7 +35,7 @@ jobs: run: shell: bash env: - PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + VENV_DIR: ${{ github.workspace }}/.venv steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -44,16 +44,12 @@ jobs: uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{matrix.version}} - - - name: Cache Pip Packages - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 - id: cache-pip - with: - path: ${{ env.PIP_CACHE_DIR }} - key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }} + - name: Create Python venv + run: python -m venv ${VENV_DIR} - name: Install sharktank deps run: | + source ${VENV_DIR}/bin/activate python -m pip install --no-compile --upgrade pip # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU @@ -72,8 +68,12 @@ jobs: iree-base-compiler \ iree-base-runtime + pip freeze + - name: Run perplexity test with IREE - run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --run-nightly-llama-tests --bs=100 --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/iree_perplexity/index.html + run: | + source ${VENV_DIR}/bin/activate + pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --run-nightly-llama-tests --bs=100 --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/iree_perplexity/index.html - name: Deploy to GitHub Pages uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 @@ -97,7 +97,7 @@ jobs: run: shell: bash env: - PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + VENV_DIR: ${{ github.workspace }}/.venv steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -106,16 +106,12 @@ jobs: uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{matrix.version}} - - - name: Cache Pip Packages - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 - id: cache-pip - with: - path: ${{ env.PIP_CACHE_DIR }} - key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }} + - name: Create Python venv + run: python -m venv ${VENV_DIR} - name: Install sharktank deps run: | + source ${VENV_DIR}/bin/activate python -m pip install --no-compile --upgrade pip # Note: We install in three steps in order to satisfy requirements # from non default locations first. 
Installing the PyTorch CPU @@ -128,7 +124,9 @@ jobs: -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" - name: Run perplexity test with Torch - run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --longrun --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/torch_perplexity/index.html + run: | + source ${VENV_DIR}/bin/activate + pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --longrun --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/torch_perplexity/index.html - name: Deploy to GitHub Pages uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 diff --git a/.github/workflows/ci_eval_short.yaml b/.github/workflows/ci_eval_short.yaml index 64043c2ec..6331e0709 100644 --- a/.github/workflows/ci_eval_short.yaml +++ b/.github/workflows/ci_eval_short.yaml @@ -34,7 +34,7 @@ jobs: run: shell: bash env: - PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + VENV_DIR: ${{ github.workspace }}/.venv steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -43,16 +43,12 @@ jobs: uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{matrix.version}} - - - name: Cache Pip Packages - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 - id: cache-pip - with: - path: ${{ env.PIP_CACHE_DIR }} - key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }} + - name: Create Python venv + run: python -m venv ${VENV_DIR} - name: Install sharktank deps run: | + source ${VENV_DIR}/bin/activate python -m pip install --no-compile --upgrade pip # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU @@ -71,5 +67,9 @@ jobs: iree-base-compiler \ iree-base-runtime + pip freeze + - name: Run perplexity test with vmfb - run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --bs=5 --iree-device='hip://6' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json + run: | + source ${VENV_DIR}/bin/activate + pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --bs=5 --iree-device='hip://6' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json diff --git a/sharktank/sharktank/layers/__init__.py b/sharktank/sharktank/layers/__init__.py index fd56ec872..620c15672 100644 --- a/sharktank/sharktank/layers/__init__.py +++ b/sharktank/sharktank/layers/__init__.py @@ -17,5 +17,6 @@ from .ffn_block import FFN from .ffn_moe_block import FFNMOE from .mixture_of_experts_block import MoeBlock +from .mmdit import MMDITDoubleBlock from .configs import * diff --git a/sharktank/sharktank/layers/mmdit.py b/sharktank/sharktank/layers/mmdit.py new file mode 100644 index 000000000..0b0750549 --- /dev/null +++ b/sharktank/sharktank/layers/mmdit.py @@ -0,0 +1,146 @@ +# Copyright 2024 Black Forest Labs. Inc. and Flux Authors +# Copyright 2024 Advanced Micro Devices, Inc. 
+# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""MMDIT Layers adapted from black-forest-labs' flux implementation +https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py +""" + +import torch.nn.functional as F +import torch +from torch import Tensor + +from .. import ops + +from .base import Theta, ThetaLayer +from .linear import LinearLayer +from .modulation import ModulationLayer +from .norm import RMSNormLayer +from .paged_llama_attention_block import PagedLlamaAttentionBlock + + +def qk_norm(q, k, v, rms_q, rms_k): + return rms_q(q).to(v), rms_k(k).to(v) + + +# TODO: Work on unifying with the current RoPE layer +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + +def attention(q, k, v, pe): + q, k = apply_rope(q, k, pe) # todo + + x = ops.scaled_dot_product_attention( + q=q, k=k, v=v, a=None, is_causal=True, scale=None + ) + x = ops.permute(x, (0, 2, 1, 3)) + x = x.view(x.shape[0], x.shape[1], -1) + + return x + + +class MMDITDoubleBlock(ThetaLayer): + def __init__(self, theta, num_heads: int): + super().__init__(theta) + + self.num_heads = num_heads + self.add_module("img_mod", ModulationLayer(theta("img_mod"), double=True)) + self.add_module("img_attn_qkv", LinearLayer(theta("img_attn.qkv"))) + self.add_module( + "img_attn_norm_q", + RMSNormLayer(theta("img_attn.norm.query_norm"), epsilon=1e-6), + ) + self.add_module( + "img_attn_norm_k", + RMSNormLayer(theta("img_attn.norm.key_norm"), epsilon=1e-6), + ) + self.add_module("img_attn_proj", LinearLayer(theta("img_attn.proj"))) + + self.add_module("img_mlp1", LinearLayer(theta("img_mlp.0"))) + self.add_module("img_mlp2", LinearLayer(theta("img_mlp.2"))) + + self.add_module("txt_mod", ModulationLayer(theta("txt_mod"), double=True)) + self.add_module("txt_attn_qkv", LinearLayer(theta("txt_attn.qkv"))) + self.add_module( + "txt_attn_norm_q", + RMSNormLayer(theta("txt_attn.norm.query_norm"), epsilon=1e-6), + ) + self.add_module( + "txt_attn_norm_k", + RMSNormLayer(theta("txt_attn.norm.key_norm"), epsilon=1e-6), + ) + self.add_module("txt_attn_proj", LinearLayer(theta("txt_attn.proj"))) + + self.add_module("txt_mlp1", LinearLayer(theta("txt_mlp.0"))) + self.add_module("txt_mlp2", LinearLayer(theta("txt_mlp.2"))) + + def forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor + ) -> tuple[Tensor, Tensor]: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = ops.layer_norm(img, None, None, eps=1e-6) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn_qkv(img_modulated) + img_qkv_2 = img_qkv.view( + img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1 + ) # + img_qkv_3 = ops.permute(img_qkv_2, (2, 0, 3, 1, 4)) + img_q, img_k, img_v = img_qkv_3 + img_q, img_k = qk_norm( + img_q, img_k, img_v, self.img_attn_norm_q, self.img_attn_norm_k + ) + + # prepare text for attention + txt_modulated = ops.layer_norm(txt, None, None, eps=1e-6) + txt_modulated = (1 + 
txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn_qkv(txt_modulated) + txt_qkv_2 = txt_qkv.view( + txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1 + ) # + txt_qkv_3 = ops.permute(txt_qkv_2, (2, 0, 3, 1, 4)) + txt_q, txt_k, txt_v = txt_qkv_3 + txt_q, txt_k = qk_norm( + txt_q, txt_k, txt_v, self.txt_attn_norm_q, self.txt_attn_norm_k + ) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn = attention(q, k, v, pe) + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the image blocks + # TODO: Refactor this for code reuse with the txt blocks + img = img + img_mod1.gate * self.img_attn_proj(img_attn) + img_mlp_in = (1 + img_mod2.scale) * ops.layer_norm( + img, None, None, eps=1e-6 + ) + img_mod2.shift + img_mlp_out1 = self.img_mlp1(img_mlp_in) + img_mlp_out2 = ops.elementwise(F.gelu, img_mlp_out1) + img_mlp_out3 = self.img_mlp2(img_mlp_out2) + img = img + img_mod2.gate * img_mlp_out3 + + # calculate the text blocks + txt = txt + txt_mod1.gate * self.txt_attn_proj(txt_attn) + txt_mlp_in = (1 + txt_mod2.scale) * ops.layer_norm( + txt, None, None, eps=1e-6 + ) + txt_mod2.shift + txt_mlp_out1 = self.txt_mlp1(txt_mlp_in) + # TODO: Unify with modulation layer by taking act_fn as an arg + txt_mlp_out2 = ops.elementwise(F.gelu, txt_mlp_out1) + txt_mlp_out3 = self.txt_mlp2(txt_mlp_out2) + txt = txt + txt_mod2.gate * txt_mlp_out3 + + return img, txt diff --git a/sharktank/sharktank/layers/modulation.py b/sharktank/sharktank/layers/modulation.py new file mode 100644 index 000000000..7ef7adfa1 --- /dev/null +++ b/sharktank/sharktank/layers/modulation.py @@ -0,0 +1,42 @@ +# Copyright 2024 Black Forest Labs. Inc. and Flux Authors +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Modulation Layer adapted from black-forest-labs' flux implementation +https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py +""" + +import torch +import torch.nn.functional as F + +from .. 
import ops + +from .base import Theta, ThetaLayer +from .linear import LinearLayer + + +class ModulationOut: + def __init__(self, shift, scale, gate): + self.shift = shift + self.scale = scale + self.gate = gate + + +class ModulationLayer(ThetaLayer): + def __init__(self, theta: Theta, double: bool): + super().__init__(theta) + + self.is_double = double + self.multiplier = 6 if double else 3 + self.add_module("lin", LinearLayer(theta("lin"))) + + def forward(self, vec: torch.Tensor) -> tuple[ModulationOut, ModulationOut | None]: + silu_result = ops.elementwise(F.silu, vec) + out = self.lin(silu_result)[:, None, :].chunk(self.multiplier, dim=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) diff --git a/sharktank/sharktank/layers/testing.py b/sharktank/sharktank/layers/testing.py index e2fc79d78..a21d5bf85 100644 --- a/sharktank/sharktank/layers/testing.py +++ b/sharktank/sharktank/layers/testing.py @@ -49,3 +49,82 @@ def make_llama_attention_block_theta( ), } ) + + +def make_mmdit_double_block_theta(dtype: torch.dtype | None = None) -> Theta: + return Theta( + { + "img_attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((128,), dtype=dtype) + ), + "img_attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((128,), dtype=dtype) + ), + "img_attn.proj.bias": DefaultPrimitiveTensor( + data=make_rand_torch((3072,), dtype=dtype) + ), + "img_attn.proj.weight": DefaultPrimitiveTensor( + data=make_rand_torch((3072, 3072), dtype=dtype) + ), + "img_attn.qkv.bias": DefaultPrimitiveTensor( + data=make_rand_torch((9216,), dtype=dtype) + ), + "img_attn.qkv.weight": DefaultPrimitiveTensor( + data=make_rand_torch((9216, 3072), dtype=dtype) + ), + "img_mlp.0.bias": DefaultPrimitiveTensor( + data=make_rand_torch((12288), dtype=dtype) + ), + "img_mlp.0.weight": DefaultPrimitiveTensor( + data=make_rand_torch((12288, 3072), dtype=dtype) + ), + "img_mlp.2.bias": DefaultPrimitiveTensor( + data=make_rand_torch((3072), dtype=dtype) + ), + "img_mlp.2.weight": DefaultPrimitiveTensor( + data=make_rand_torch((3072, 12288), dtype=dtype) + ), + "img_mod.lin.bias": DefaultPrimitiveTensor( + data=make_rand_torch((18432,), dtype=dtype) + ), + "img_mod.lin.weight": DefaultPrimitiveTensor( + data=make_rand_torch((18432, 3072), dtype=dtype) + ), + "txt_attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((128,), dtype=dtype) + ), + "txt_attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((128,), dtype=dtype) + ), + "txt_attn.proj.bias": DefaultPrimitiveTensor( + data=make_rand_torch((3072,), dtype=dtype) + ), + "txt_attn.proj.weight": DefaultPrimitiveTensor( + data=make_rand_torch((3072, 3072), dtype=dtype) + ), + "txt_attn.qkv.bias": DefaultPrimitiveTensor( + data=make_rand_torch((9216,), dtype=dtype) + ), + "txt_attn.qkv.weight": DefaultPrimitiveTensor( + data=make_rand_torch((9216, 3072), dtype=dtype) + ), + "txt_mlp.0.bias": DefaultPrimitiveTensor( + data=make_rand_torch((12288), dtype=dtype) + ), + "txt_mlp.0.weight": DefaultPrimitiveTensor( + data=make_rand_torch((12288, 3072), dtype=dtype) + ), + "txt_mlp.2.bias": DefaultPrimitiveTensor( + data=make_rand_torch((3072), dtype=dtype) + ), + "txt_mlp.2.weight": DefaultPrimitiveTensor( + data=make_rand_torch((3072, 12288), dtype=dtype) + ), + "txt_mod.lin.bias": DefaultPrimitiveTensor( + data=make_rand_torch((18432,), dtype=dtype) + ), + "txt_mod.lin.weight": DefaultPrimitiveTensor( + data=make_rand_torch((18432, 3072), 
dtype=dtype)
+            ),
+        }
+    )
diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py
index d117ada23..47e737fb1 100644
--- a/sharktank/sharktank/ops/default_impls.py
+++ b/sharktank/sharktank/ops/default_impls.py
@@ -304,16 +304,26 @@ def interpolate_default(
     )


-@layer_norm.override(Tensor, Tensor, Tensor)
 def layer_norm_default(input, weight, bias, *, eps):
     input = unbox_tensor(input)
-    weight = unbox_tensor(weight)
-    bias = unbox_tensor(bias)
+    if weight is not None:
+        weight = unbox_tensor(weight)
+    else:
+        weight = torch.ones(input.shape, dtype=input.dtype)
+    if bias is not None:
+        bias = unbox_tensor(bias)
+    else:
+        bias = torch.zeros(input.shape, dtype=input.dtype)
     return F.layer_norm(
         input, normalized_shape=weight.shape, weight=weight, bias=bias, eps=eps
     )


+layer_norm.override(Tensor)(layer_norm_default)
+layer_norm.override(Tensor, Tensor)(layer_norm_default)
+layer_norm.override(Tensor, Tensor, Tensor)(layer_norm_default)
+
+
 # Linear
 def linear_default(input, weight, bias, *, accum_dtype) -> Tensor:
     input = unbox_tensor(input)
diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py
index 408f00ec7..dc7fb108a 100644
--- a/sharktank/sharktank/ops/signatures.py
+++ b/sharktank/sharktank/ops/signatures.py
@@ -582,12 +582,14 @@ def layer_norm(
 def _layer_norm_trampoline(
     d: SignatureDispatcher,
     input: AnyTensor,
-    weight: AnyTensor,
+    weight: Optional[AnyTensor],
     bias: Optional[AnyTensor],
     *,
     eps: float,
 ):
-    tensors = [input, weight]
+    tensors = [input]
+    if weight is not None:
+        tensors.append(weight)
     if bias is not None:
         tensors.append(bias)
     for override in d.find_overrides(tensors):
diff --git a/sharktank/tests/layers/mmdit_test.py b/sharktank/tests/layers/mmdit_test.py
new file mode 100644
index 000000000..5bd5ce39a
--- /dev/null
+++ b/sharktank/tests/layers/mmdit_test.py
@@ -0,0 +1,58 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging + +logging.basicConfig(level=logging.DEBUG) + +import unittest + +import torch + +from iree.turbine import aot +from sharktank.layers import ( + MMDITDoubleBlock, +) +import sharktank.ops as ops +from sharktank.layers.testing import ( + make_mmdit_double_block_theta, +) +from sharktank.types.tensors import DefaultPrimitiveTensor + + +class MMDITTest(unittest.TestCase): + def setUp(self): + torch.manual_seed(12345) + self.hidden_size = 3072 + self.num_heads = 24 + self.batch_size = 3 + + def testDoubleExport(self): + + theta = make_mmdit_double_block_theta() + mmdit = MMDITDoubleBlock( + theta=theta, + num_heads=self.num_heads, + ) + + img = torch.rand([self.batch_size, 1024, self.hidden_size]) + txt = torch.rand([self.batch_size, 512, self.hidden_size]) + vec = torch.rand([self.batch_size, self.hidden_size]) + rot = torch.rand([self.batch_size, 1, 1536, 64, 2, 2]) + mmdit.forward(img, txt, vec, rot) + fxb = aot.FxProgramsBuilder(mmdit) + + @fxb.export_program(name="mmdit", args=(img, txt, vec, rot), strict=False) + def _(model, img, txt, vec, rot) -> torch.Tensor: + return model.forward(img, txt, vec, rot) + + output = aot.export(fxb) + output.verify() + asm = str(output.mlir_module) + + +if __name__ == "__main__": + unittest.main() diff --git a/shortfin/build_tools/build_linux_package.sh b/shortfin/build_tools/build_linux_package.sh index afaa1e9fb..91b944e51 100755 --- a/shortfin/build_tools/build_linux_package.sh +++ b/shortfin/build_tools/build_linux_package.sh @@ -14,9 +14,10 @@ # Build everything (all python versions): # sudo ./build_tools/build_linux_package.sh # -# Build specific Python versions to custom directory: +# Build specific Python versions to custom directory, with tracing enabled: # OVERRIDE_PYTHON_VERSIONS="cp312-cp312 cp313-cp313" \ # OUTPUT_DIR="/tmp/wheelhouse" \ +# SHORTFIN_ENABLE_TRACING="ON" \ # sudo -E ./build_tools/build_linux_package.sh # # Valid Python versions match a subdirectory under /opt/python in the docker @@ -40,6 +41,8 @@ ARCH="$(uname -m)" MANYLINUX_DOCKER_IMAGE="${MANYLINUX_DOCKER_IMAGE:-quay.io/pypa/manylinux_2_28_${ARCH}:latest}" PYTHON_VERSIONS="${OVERRIDE_PYTHON_VERSIONS:-cp311-cp311 cp312-cp312 cp313-cp313}" OUTPUT_DIR="${OUTPUT_DIR:-${THIS_DIR}/wheelhouse}" +CACHE_DIR="${CACHE_DIR:-}" +SHORTFIN_ENABLE_TRACING="${SHORTFIN_ENABLE_TRACING:-ON}" function run_on_host() { echo "Running on host" @@ -50,12 +53,23 @@ function run_on_host() { OUTPUT_DIR="$(cd "${OUTPUT_DIR}" && pwd)" echo "Outputting to ${OUTPUT_DIR}" mkdir -p "${OUTPUT_DIR}" + + # Setup cache as needed. + extra_args="" + if ! [ -z "$CACHE_DIR" ]; then + echo "Setting up host cache dir ${CACHE_DIR}" + mkdir -p "${CACHE_DIR}/ccache" + extra_args="${extra_args} -v ${CACHE_DIR}:${CACHE_DIR} -e CACHE_DIR=${CACHE_DIR}" + fi + docker run --rm \ -v "${REPO_ROOT}:${REPO_ROOT}" \ -v "${OUTPUT_DIR}:${OUTPUT_DIR}" \ -e __MANYLINUX_BUILD_WHEELS_IN_DOCKER=1 \ -e "OVERRIDE_PYTHON_VERSIONS=${PYTHON_VERSIONS}" \ -e "OUTPUT_DIR=${OUTPUT_DIR}" \ + -e "SHORTFIN_ENABLE_TRACING=${SHORTFIN_ENABLE_TRACING}" \ + ${extra_args} \ "${MANYLINUX_DOCKER_IMAGE}" \ -- ${THIS_DIR}/${SCRIPT_NAME} @@ -72,6 +86,23 @@ function run_in_docker() { echo "Using python versions: ${PYTHON_VERSIONS}" local orig_path="${PATH}" + # Configure caching. + if [ -z "$CACHE_DIR" ]; then + echo "Cache directory not configured. No caching will take place." 
+  else
+    # TODO: include this in the dockerfile we use so it gets cached
+    install_ccache
+
+    # TODO: debug low cache hit rate (~30% hits out of 98% cacheable) on CI
+    mkdir -p "${CACHE_DIR}"
+    CACHE_DIR="$(cd ${CACHE_DIR} && pwd)"
+    echo "Caching build artifacts to ${CACHE_DIR}"
+    export CCACHE_DIR="${CACHE_DIR}/ccache"
+    export CCACHE_MAXSIZE="2G"
+    export CMAKE_C_COMPILER_LAUNCHER=ccache
+    export CMAKE_CXX_COMPILER_LAUNCHER=ccache
+  fi
+
   # Build phase.
   echo "******************** BUILDING PACKAGE ********************"
   for python_version in ${PYTHON_VERSIONS}; do
@@ -82,14 +113,44 @@
     fi
     export PATH="${python_dir}/bin:${orig_path}"
     echo ":::: Python version $(python --version)"
+
     clean_wheels "shortfin" "${python_version}"
     build_shortfin
     run_audit_wheel "shortfin" "${python_version}"
+
+    if ! [ -z "$CACHE_DIR" ]; then
+      echo "ccache stats:"
+      ccache --show-stats
+    fi
   done
 }

+function install_ccache() {
+  # This gets an old version.
+  # yum install -y ccache
+
+  CCACHE_VERSION="4.10.2"
+
+  if [[ "${ARCH}" == "x86_64" ]]; then
+    curl --silent --fail --show-error --location \
+        "https://github.com/ccache/ccache/releases/download/v${CCACHE_VERSION}/ccache-${CCACHE_VERSION}-linux-${ARCH}.tar.xz" \
+        --output ccache.tar.xz
+
+    tar xf ccache.tar.xz
+    cp ccache-${CCACHE_VERSION}-linux-${ARCH}/ccache /usr/local/bin
+  elif [[ "${ARCH}" == "aarch64" ]]; then
+    # Latest version of ccache is not released for arm64, build it
+    git clone --depth 1 --branch "v${CCACHE_VERSION}" https://github.com/ccache/ccache.git
+    mkdir -p ccache/build && cd "$_"
+    cmake -G "Ninja" -DCMAKE_BUILD_TYPE=Release ..
+    ninja
+    cp ccache /usr/bin/
+  fi
+}
+
 function build_shortfin() {
-  export SHORTFIN_ENABLE_TRACING=ON
+  # Note: The SHORTFIN_ENABLE_TRACING environment variable should have been
+  # forwarded from the host environment into Docker above.
   python -m pip wheel --disable-pip-version-check -v -w "${OUTPUT_DIR}" "${REPO_ROOT}/shortfin"
 }
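
Note on the layer_norm change in sharktank/sharktank/ops (default_impls.py and signatures.py above): weight and bias are now optional, the trampoline collects only the tensors that are actually present before dispatch, and the same default implementation is registered for all three arities. A minimal usage sketch of the resulting call patterns, assuming the ops.layer_norm signature shown in this diff and that the dispatcher resolves overrides from the collected tensor list:

import torch
from sharktank import ops

x = torch.rand(2, 8, 16)
w = torch.ones(16)
b = torch.zeros(16)

# (Tensor,) override: weight and bias are synthesized as ones/zeros of
# input's shape inside layer_norm_default.
y0 = ops.layer_norm(x, None, None, eps=1e-6)
# (Tensor, Tensor) override: weight given, bias defaulted.
y1 = ops.layer_norm(x, w, None, eps=1e-6)
# (Tensor, Tensor, Tensor) override: both given.
y2 = ops.layer_norm(x, w, b, eps=1e-6)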
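The rot tensor shape in mmdit_test.py follows from apply_rope in mmdit.py: hidden size 3072 over 24 heads gives a head dimension of 128, freqs_cis carries one 2x2 rotation per channel pair, and sequence length 1536 is the 512 text tokens concatenated with the 1024 image tokens. A standalone shape check, illustrative only:

import torch

B, H, L, D = 3, 24, 1536, 128  # batch, heads, txt+img tokens, head dim
xq = torch.rand(B, H, L, D)
freqs_cis = torch.rand(B, 1, L, D // 2, 2, 2)  # matches rot in the test

# apply_rope views the last dim as D/2 (real, imag) pairs...
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)  # (B, H, L, D/2, 1, 2)
# ...and rotates each pair by the trailing 2x2 matrix of freqs_cis.
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
assert xq_out.reshape(*xq.shape).shape == (B, H, L, D)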
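Similarly, ModulationLayer in modulation.py projects the conditioning vector to multiplier * hidden features and chunks the result into (shift, scale, gate) triples; with double=True the second triple drives the MLP branch of MMDITDoubleBlock. A shape-only sketch under the same assumed sizes:

import torch

batch, hidden = 3, 3072
# Stand-in for self.lin(F.silu(vec)) with double=True, i.e. multiplier = 6.
lin_out = torch.rand(batch, 6 * hidden)
out = lin_out[:, None, :].chunk(6, dim=-1)
assert all(t.shape == (batch, 1, hidden) for t in out)
# ModulationOut(*out[:3]) -> mod1; ModulationOut(*out[3:]) -> mod2.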