diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS deleted file mode 100644 index 4b13c668f618..000000000000 --- a/.github/CODEOWNERS +++ /dev/null @@ -1,83 +0,0 @@ -# Codeowners for IREE Github Repository. -# The listed owners will automatically be added as reviewers to PRs that modify -# paths matching the specified patterns. -# Refer to https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners -# for syntax of this file (tl;dr: syntax is like .gitignore. Last matching rule -# takes precedence). -# Because of the precedence, rules for directories are listed topologically. -# @ghost is used to make a pattern have no owners. It is a sentinel GitHub user -# that takes the place of deleted users. - -# No global owners because we don't really want e.g. changing the root -# CMakeLists.txt file to always ping a bunch of people. - -# Third-Party Code -/.gitmodules @ScottTodd @stellaraccident -/third_party/ @ScottTodd @stellaraccident -# Except for routinely-updated submodules -/third_party/llvm-project @ghost -/third_party/llvm-project.branch-pin @ghost -/third_party/stablehlo @ghost -/third_party/torch-mlir @ghost - -# Bindings -/runtime/bindings/python/ @stellaraccident -/runtime/bindings/tflite/ @benvanik - -# Integrations -/integrations/ @benvanik @stellaraccident -/integrations/tensorflow/ @stellaraccident -/integrations/tensorflow/test/**/iree_tfl_tests/ @rsuderman - -# Experimental -# It's experimental, but we still don't want any old directory added here. -/experimental/ @benvanik @stellaraccident -/experimental/cpu_ukernel/ @bjacob -/experimental/cuda2/ @antiagainst -/experimental/dispatch_profiler/ @manishucsd -/experimental/rocm/ @benvanik -/experimental/web/ @ScottTodd -/experimental/webgpu/ @benvanik @ScottTodd - -# Infra Top-Level Directories -/build_tools/ @ScottTodd @pzread -/build_tools/benchmarks/ @antiagainst @pzread -/build_tools/python/ @pzread -/build_tools/python_deploy/ @stellaraccident -/build_tools/scripts/ @ScottTodd -/build_tools/third_party/ @ScottTodd @stellaraccident -/.github/ @ScottTodd - -# llvm-external-projects -/llvm-external-projects/ @stellaraccident -/llvm-external-projects/iree-dialects/ @MaheshRavishankar -/llvm-external-projects/iree-dialects/**/Dialect/LinalgExt/ @hanhanW @MaheshRavishankar -/llvm-external-projects/iree-dialects/test/iree_linalgext @hanhanW @MaheshRavishankar - -# Other Top-Level Directories -/docs/ @ScottTodd -/samples/ @ScottTodd -/tools/ @benvanik - -# Compiler -/compiler/src/iree/compiler/ @benvanik -/compiler/src/iree/compiler/Codegen/ @MaheshRavishankar -/compiler/src/iree/compiler/Codegen/Common @hanhanW @dcaballe -/compiler/src/iree/compiler/Codegen/Common/GPU @antiagainst @qedawkins -/compiler/src/iree/compiler/Codegen/LLVMCPU/ @dcaballe @hanhanW @MaheshRavishankar -/compiler/src/iree/compiler/Codegen/LLVMGPU/ @MaheshRavishankar -/compiler/src/iree/compiler/Codegen/SPIRV/ @antiagainst @MaheshRavishankar -/compiler/src/iree/compiler/Codegen/TransformStrategies/ @qedawkins @MaheshRavishankar -/compiler/src/iree/compiler/ConstEval/ @hanhanW @stellaraccident -/compiler/src/iree/compiler/Dialect/Flow/ @hanhanW @MaheshRavishankar -/compiler/src/iree/compiler/Dialect/Vulkan/ @antiagainst -/compiler/src/iree/compiler/GlobalOptimization/ @hanhanW -/compiler/src/iree/compiler/InputConversion/ @MaheshRavishankar @stellaraccident -/compiler/plugins/input/StableHLO/ @hanhanW @MaheshRavishankar @rsuderman -/compiler/plugins/input/TOSA/ @MaheshRavishankar 
@rsuderman - -# Runtime -/runtime/src/iree/ @benvanik -/runtime/src/iree/hal/cts/ @ScottTodd -/runtime/src/iree/hal/drivers/metal/ @antiagainst -/runtime/src/iree/hal/drivers/vulkan/ @antiagainst @ScottTodd diff --git a/.github/workflows/build_package.yml b/.github/workflows/build_package.yml index 7e044ea4a1d6..172f09f9c7ea 100644 --- a/.github/workflows/build_package.yml +++ b/.github/workflows/build_package.yml @@ -39,85 +39,81 @@ jobs: matrix: include: # Ubuntu packages. - - runs-on: [managed-releaser, os-family=Linux, runner-group=releaser] - build-family: linux-x86_64 - build-package: main-dist-linux - experimental: false - - runs-on: [self-hosted, arm64, os-family=Linux, runner-group=postsubmit] - build-family: linux-aarch64 - build-package: main-dist-linux - experimental: true - - runs-on: [managed-releaser, os-family=Linux, runner-group=releaser] - build-family: linux-x86_64 + - runs-on: icelake + build-family: linux build-package: py-compiler-pkg experimental: false - - runs-on: [self-hosted, arm64, os-family=Linux, runner-group=postsubmit] - build-family: linux-aarch64 - build-package: py-compiler-pkg - experimental: true - - runs-on: [managed-releaser, os-family=Linux, runner-group=releaser] - build-family: linux-x86_64 + - runs-on: icelake + build-family: linux build-package: py-runtime-pkg experimental: false - - runs-on: [self-hosted, arm64, os-family=Linux, runner-group=postsubmit] - build-family: linux-aarch64 - build-package: py-runtime-pkg - experimental: true - - runs-on: [managed-releaser, os-family=Linux, runner-group=releaser] - build-family: linux-x86_64 - build-package: py-tf-compiler-tools-pkg - experimental: false - # MacOS packages. - - runs-on: - - ${{ github.repository == 'openxla/iree' && 'self-hosted' || 'macos-11' }} - - os-family=macOS - - runner-group=postsubmit + # Macos packages. + - runs-on: MacStudio build-family: macos build-package: py-compiler-pkg experimental: true - - runs-on: - - ${{ github.repository == 'openxla/iree' && 'self-hosted' || 'macos-11' }} - - os-family=macOS - - runner-group=postsubmit + - runs-on: MacStudio build-family: macos build-package: py-runtime-pkg experimental: true # Windows packages. - runs-on: - - ${{ github.repository == 'openxla/iree' && 'windows-2022-64core' || 'windows-2022'}} + - ${{ github.repository == 'openxla/iree' && 'windows-2022-64core' || '7950X'}} build-family: windows build-package: py-compiler-pkg experimental: true - - runs-on: windows-2022 + - runs-on: 7950X build-family: windows build-package: py-runtime-pkg experimental: true + # Linux AArch64 packages. + - runs-on: linux-aarch64 + build-family: linux-aarch64 + build-package: py-compiler-pkg + experimental: false + - runs-on: linux-aarch64 + build-family: linux-aarch64 + build-package: py-runtime-pkg + experimental: false + + env: # These are also set in: build_tools/python_deploy/build_linux_packages.sh MANYLINUX_X86_64_IMAGE: ghcr.io/nod-ai/manylinux_x86_64:main MANYLINUX_AARCH64_IMAGE: quay.io/pypa/manylinux_2_28_aarch64 steps: + # Docker may leave root owned files + - name: Chown user + if: "matrix.build-family == 'linux-aarch64' || matrix.build-family == 'linux'" + run: | + sudo chown -R $USER:$USER $GITHUB_WORKSPACE - uses: actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3 # v3.5.0 with: path: "c" # Windows can hit path length limits, so use a short path. 
submodules: true ref: ${{ github.event.inputs.commit }} + - uses: actions/setup-python@v4 + if: "matrix.build-family == 'windows'" + with: + python-version: '3.11' ########################################################################## # OS specific setup ########################################################################## - - name: Install dependencies (Windows) - if: "matrix.build-family == 'windows'" - shell: powershell - run: ./c/build_tools/python_deploy/install_windows_deps.ps1 + #- name: Install dependencies (Windows) + # if: "matrix.build-family == 'windows'" + # shell: powershell + # run: ./c/build_tools/python_deploy/install_windows_deps.ps1 - name: "Configure MSVC (Windows)" if: "matrix.build-family == 'windows'" uses: ilammy/msvc-dev-cmd@7315a94840631165970262a99c72cfb48a65d25d # v1.12.0 + with: + arch: x64 ########################################################################## # Write version_info.json @@ -240,6 +236,16 @@ jobs: [ -e ./bindist/* ] && rm ./bindist/* ./c/build_tools/python_deploy/build_linux_packages.sh + - name: Build compiler wheels (Linux-AArch64) + if: "matrix.build-package == 'py-compiler-pkg' && matrix.build-family == 'linux-aarch64'" + shell: bash + env: + package_suffix: ${{ github.event.inputs.package_suffix }} + packages: "iree-compiler" + output_dir: "${{ github.workspace }}/bindist" + run: | + ./c/build_tools/python_deploy/build_linux_packages.sh + - name: Build compiler wheels (MacOS) if: "matrix.build-package == 'py-compiler-pkg' && matrix.build-family == 'macos'" shell: bash @@ -288,10 +294,10 @@ jobs: path: ./bindist/* retention-days: 5 - # TODO: Upload the tar.bz2 files too when ready - - name: Upload Release Assets - if: github.event.inputs.release_id != '' - id: upload-release-assets + # TODO: One Window Release builds we build both compiler+runtime + - name: Upload Release Assets (Windows) + if: "github.event.inputs.release_id != '' && matrix.build-family == 'windows'" + id: upload-release-assets-windows uses: dwenegar/upload-release-assets@5bc3024cf83521df8ebfadf00ad0c4614fd59148 # v1 env: GITHUB_TOKEN: ${{ secrets.WRITE_ACCESS_TOKEN }} @@ -300,6 +306,29 @@ jobs: # Only upload iree artifacts. assets_path: ./bindist/iree*.* + # TODO: Upload the tar.bz2 files too when ready + - name: Upload Release Assets (Compiler) + if: "github.event.inputs.release_id != '' && matrix.build-package == 'py-compiler-pkg' && matrix.build-family != 'windows'" + id: upload-release-assets-compiler + uses: dwenegar/upload-release-assets@5bc3024cf83521df8ebfadf00ad0c4614fd59148 # v1 + env: + GITHUB_TOKEN: ${{ secrets.WRITE_ACCESS_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + # Only upload iree artifacts. + assets_path: ./bindist/iree_compiler*.* + + - name: Upload Release Assets (Runtime) + if: "github.event.inputs.release_id != '' && matrix.build-package == 'py-runtime-pkg' && matrix.build-family != 'windows'" + id: upload-release-assets-runtime + uses: dwenegar/upload-release-assets@5bc3024cf83521df8ebfadf00ad0c4614fd59148 # v1 + env: + GITHUB_TOKEN: ${{ secrets.WRITE_ACCESS_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + # Only upload iree artifacts. 
+ assets_path: ./bindist/iree_runtime*.* + validate_and_publish: name: "Trigger validate and publish release" needs: build_packages diff --git a/.github/workflows/publish_website.yml b/.github/workflows/publish_website.yml index 28803fb23fbf..e41b94627abe 100644 --- a/.github/workflows/publish_website.yml +++ b/.github/workflows/publish_website.yml @@ -44,13 +44,6 @@ jobs: with: python-version: 3.x cache: 'pip' - - id: "gcp-auth" - name: "Authenticating to Google Cloud" - uses: "google-github-actions/auth@v1" - with: - token_format: "access_token" - credentials_json: "${{ secrets.IREE_OSS_GITHUB_RUNNER_BASIC_TRUST_SERVICE_ACCOUNT_KEY }}" - create_credentials_file: false - name: Installing dependencies run: | pip install -r docs/website/requirements.txt @@ -60,14 +53,6 @@ jobs: ./build_tools/scripts/generate_release_index.py \ --repo="${GITHUB_REPOSITORY}" \ --output=docs/website/docs/pip-release-links.html - - name: Building documentation files - run: | - ./build_tools/github_actions/docker_run.sh \ - --env "IREE_CCACHE_GCP_TOKEN=${{ steps.gcp-auth.outputs.access_token }}" \ - --env "IREE_WRITE_REMOTE_CCACHE=1" \ - --env "CCACHE_NAMESPACE=gcr.io/iree-oss/base@sha256:796fb81a11ff7e7d057c93de468b74e48b6a9641aa19b7f7673c2772e8ea3b33" \ - gcr.io/iree-oss/base@sha256:796fb81a11ff7e7d057c93de468b74e48b6a9641aa19b7f7673c2772e8ea3b33 \ - ./docs/website/generate_extra_files.sh - name: Setting git config run: | git config --local user.email "iree-github-actions-bot@google.com" diff --git a/.github/workflows/sync.yml b/.github/workflows/sync.yml new file mode 100644 index 000000000000..2fb41e7b7c68 --- /dev/null +++ b/.github/workflows/sync.yml @@ -0,0 +1,69 @@ +name: 'Sync Upstream' + +on: + workflow_dispatch: + schedule: + - cron: '0 * * * *' + +jobs: + sync_upstream: + name: 'Sync Upstream' + runs-on: ubuntu-latest + steps: + - name: Checking out repository + uses: actions/checkout@v3 + with: + token: ${{ secrets.CI_WRITE_TOKEN }} + repository: nod-ai/shark-runtime + ref: main + fetch-depth: 0 + + - name: Setup git + run: | + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "SHARK bot" + + - name: Update main upstream + run: | + set -ex + git remote add upstream https://github.com/iree-org/iree + git pull --ff-only upstream main + + - name: Pushing changes + uses: ad-m/github-push-action@master + with: + github_token: ${{ secrets.CI_WRITE_TOKEN }} + branch: main + repository: nod-ai/shark-runtime + + rebase_shark: + name: 'Rebase SHARK' + runs-on: ubuntu-latest + steps: + - name: Checking out repository + uses: actions/checkout@v3 + with: + token: ${{ secrets.CI_WRITE_TOKEN }} + repository: nod-ai/shark-runtime + ref: shark + fetch-depth: 0 + + - name: Setup git + run: | + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "SHARK bot" + + - name: Update shark upstream + run: | + set -ex + git remote add upstream https://github.com/iree-org/iree + git fetch upstream + git rebase upstream/main + + - name: Pushing changes + uses: ad-m/github-push-action@master + with: + github_token: ${{ secrets.CI_WRITE_TOKEN }} + branch: shark + repository: nod-ai/shark-runtime + force_with_lease: true diff --git a/.github/workflows/validate_and_publish_release.yml b/.github/workflows/validate_and_publish_release.yml index 41d19b40c9f9..e86c2665ea6e 100644 --- a/.github/workflows/validate_and_publish_release.yml +++ b/.github/workflows/validate_and_publish_release.yml @@ -16,100 +16,8 @@ on: required: 
true jobs: - validate_packages: - name: "Validate packages" - # TODO(jennik): Look into testing windows and macos builds. - runs-on: ubuntu-20.04 - steps: - - name: Download packages - id: download_packages - uses: dawidd6/action-download-artifact@5e780fc7bbd0cac69fc73271ed86edf5dcb72d67 # v2.26.0 - with: - github_token: ${{secrets.WRITE_ACCESS_TOKEN}} - workflow: build_package.yml - run_id: ${{ github.event.inputs.build_run_id }} - - name: Extract and display downloaded files - run: | - tar -xf artifact/iree-dist-${{ github.event.inputs.package_version }}-linux-x86_64.tar.xz - pwd - ls -R - - name: Set up python - id: set_up_python - uses: actions/setup-python@d27e3f3d7c64b4bbf8e4abfb9b63b83e846e0435 # v4.5.0 - with: - python-version: "3.9" - - name: Install python packages - id: install_python_packages - run: | - python -m pip install -f file://$PWD/artifact/ iree-compiler iree-runtime iree-tools-tflite iree-tools-tf - - name: Validate IREE Runtime Package - id: validate_runtime_package - run: | - echo "Testing default runtime:" - python -m iree.runtime._package_test - echo "Testing tracy runtime:" - # GH runners don't expose the TSC but we want to make sure the basic packaging - # works, so override the check with TRACY_NO_INVARIANT_CHECK=1 (per instructions - # if this is left off). - TRACY_NO_INVARIANT_CHECK=1 IREE_PY_RUNTIME=tracy \ - python -m iree.runtime._package_test - # Binaries from the tarball - - name: Run iree-benchmark-module - id: run_iree_benchmark_module - run: ./bin/iree-benchmark-module --help - - name: Run iree-benchmark-trace - id: run_iree_benchmark_trace - run: ./bin/iree-benchmark-trace --help - - name: Run iree-dump-module - id: run_iree_dump_module - run: ./bin/iree-dump-module --help - - name: Run iree-cpuinfo - id: run_iree_cpuinfo - run: ./bin/iree-cpuinfo - - name: Run iree-flatcc-cli - id: run_iree_flatcc_cli - run: ./bin/iree-flatcc-cli --help - - name: Run iree-opt - id: run_iree_opt - run: ./bin/iree-opt --help - - name: Run iree-run-mlir - id: run_iree_run_mlir - run: ./bin/iree-run-mlir --help - - name: Run iree-run-module - id: run_iree_run_module - run: ./bin/iree-run-module --help - - name: Run iree-run-trace - id: run_iree_run_trace - run: ./bin/iree-run-trace --help - - name: Run iree-tblgen - id: run_iree_tblgen - run: ./bin/iree-tblgen --help - - name: Run iree-compile - id: run_iree-compile - run: ./bin/iree-compile --help - # Console scripts from the wheels. - - name: Py iree-run-module - id: py_iree-run-module - run: iree-run-module --help - - name: Py iree-run-trace - id: py_iree-run-trace - run: iree-run-trace --help - - name: Py iree-benchmark-module - id: py_iree_benchmark_module - run: iree-benchmark-module --help - - name: Py iree-benchmark-trace - id: py_iree_benchmark_trace - run: iree-benchmark-trace --help - - name: Py iree-dump-module - id: py_iree_dump_module - run: iree-dump-module --help - - name: Py iree-cpuinfo - id: py_iree_cpuinfo - run: iree-cpuinfo - publish_release: name: "Publish release" - needs: validate_packages runs-on: ubuntu-20.04 steps: - name: Publish Release @@ -120,17 +28,3 @@ jobs: with: release_id: ${{ github.event.inputs.release_id }} - - name: Checking out repository - uses: actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3 # v3.5.0 - with: - token: ${{ secrets.WRITE_ACCESS_TOKEN }} - # Get all history. Otherwise the latest-snapshot branch can't be - # fast-forwarded. 
- fetch-depth: 0 - - - name: Updating latest-snapshot branch - uses: ad-m/github-push-action@40bf560936a8022e68a3c00e7d2abefaf01305a6 # v0.6.0 - with: - github_token: ${{ secrets.WRITE_ACCESS_TOKEN }} - branch: latest-snapshot - force: true diff --git a/build_tools/benchmarks/reporting/requirements.txt b/build_tools/benchmarks/reporting/requirements.txt index 9dcb2d2452cd..cb7ee9904e47 100644 --- a/build_tools/benchmarks/reporting/requirements.txt +++ b/build_tools/benchmarks/reporting/requirements.txt @@ -1,2 +1,2 @@ pandas==1.5.0 -jinja2==2.11.3 \ No newline at end of file +jinja2==3.1.3 \ No newline at end of file diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index 9116db182c2d..c2739fc90585 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -65,7 +65,7 @@ this_dir="$(cd $(dirname $0) && pwd)" script_name="$(basename $0)" repo_root=$(cd "${this_dir}" && find_git_dir_parent) manylinux_docker_image="${manylinux_docker_image:-$(uname -m | awk '{print ($1 == "aarch64") ? "quay.io/pypa/manylinux_2_28_aarch64" : "ghcr.io/nod-ai/manylinux_x86_64:main" }')}" -python_versions="${override_python_versions:-cp39-cp39 cp310-cp310 cp311-cp311}" +python_versions="${override_python_versions:-cp311-cp311}" output_dir="${output_dir:-${this_dir}/wheelhouse}" packages="${packages:-iree-runtime iree-compiler}" package_suffix="${package_suffix:-}" @@ -157,10 +157,12 @@ function build_iree_runtime() { export IREE_RUNTIME_BUILD_TRACY=ON # We install the needed build deps below for the tools. export IREE_RUNTIME_BUILD_TRACY_TOOLS=ON + export IREE_EXTERNAL_HAL_DRIVERS="rocm" build_wheel runtime/ } function build_iree_compiler() { + export IREE_TARGET_BACKEND_ROCM=ON build_wheel compiler/ } diff --git a/build_tools/python_deploy/build_windows_packages.ps1 b/build_tools/python_deploy/build_windows_packages.ps1 index 43906f8800f7..a8fbd3a9aadf 100644 --- a/build_tools/python_deploy/build_windows_packages.ps1 +++ b/build_tools/python_deploy/build_windows_packages.ps1 @@ -67,11 +67,13 @@ function run() { function build_iree_runtime() { param($python_version) $env:IREE_HAL_DRIVER_VULKAN = "ON" + $env:IREE_EXTERNAL_HAL_DRIVERS = "rocm" & py -${python_version} -m pip wheel -v -w $output_dir $repo_root/runtime/ } function build_iree_compiler() { param($python_version) + $env:IREE_TARGET_BACKEND_ROCM= "ON" py -${python_version} -m pip wheel -v -w $output_dir $repo_root/compiler/ } diff --git a/build_tools/scripts/get_latest_green.sh b/build_tools/scripts/get_latest_green.sh index 979acb2b6ec0..ea08d08125ce 100755 --- a/build_tools/scripts/get_latest_green.sh +++ b/build_tools/scripts/get_latest_green.sh @@ -36,17 +36,6 @@ function get_latest_green() { local query_string="$(IFS="&" ; echo "${query_params[*]}")" local all_passing="true" - for workflow in "${REQUIRED_WORKFLOWS[@]}"; do - local successful_run_count="$(\ - gh api --jq '.total_count' \ - "/repos/openxla/iree/actions/workflows/${workflow}/runs?${query_string}" \ - )" - # Any successful run of the workflow (including reruns) is OK. 
- if (( successful_run_count==0 )); then - all_passing="false" - break - fi - done if [[ "${all_passing}" == true ]]; then echo "${commit}" return 0 diff --git a/compiler/bindings/python/CMakeLists.txt b/compiler/bindings/python/CMakeLists.txt index e47ff50ff964..6817758ae010 100644 --- a/compiler/bindings/python/CMakeLists.txt +++ b/compiler/bindings/python/CMakeLists.txt @@ -210,6 +210,13 @@ add_iree_compiler_busybox_tool( IREECompileTool.c ) +add_iree_compiler_busybox_tool( + IREECompilerIREEOptTool + OUTPUT_NAME iree-opt + SRCS + IREEOptTool.c +) + if(TARGET lld) add_iree_compiler_busybox_tool( IREECompilerLldTool diff --git a/compiler/bindings/python/IREEOptTool.c b/compiler/bindings/python/IREEOptTool.c new file mode 100644 index 000000000000..5c3f24133251 --- /dev/null +++ b/compiler/bindings/python/IREEOptTool.c @@ -0,0 +1,9 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/tool_entry_points_api.h" + +int main(int argc, char **argv) { return ireeOptRunMain(argc, argv); } diff --git a/compiler/bindings/python/iree/compiler/tools/binaries.py b/compiler/bindings/python/iree/compiler/tools/binaries.py index 7b8e592bf297..0ef1e604ebb9 100644 --- a/compiler/bindings/python/iree/compiler/tools/binaries.py +++ b/compiler/bindings/python/iree/compiler/tools/binaries.py @@ -30,6 +30,7 @@ _BUILTIN_TOOLS = [ "iree-compile", + "iree-opt", "iree-lld", ] @@ -42,6 +43,7 @@ # options. "iree-compile": "iree.tools.core", "iree-lld": "iree.tools.core", + "iree-opt": "iree.tools.core", "iree-import-tflite": "iree.tools.tflite", "iree-import-tf": "iree.tools.tf", } diff --git a/compiler/bindings/python/iree/compiler/tools/core.py b/compiler/bindings/python/iree/compiler/tools/core.py index 4eb7bc66229a..e019d7d767fe 100644 --- a/compiler/bindings/python/iree/compiler/tools/core.py +++ b/compiler/bindings/python/iree/compiler/tools/core.py @@ -25,6 +25,8 @@ "CompilerOptions", "InputType", "OutputFormat", + "preprocess_file", + "preprocess_str", ] # Default testing backend for invoking the compiler. @@ -318,3 +320,160 @@ def query_available_targets(): target_backends = [target for target in target_backends if target] return target_backends + + +# Preprocessing for SHARK (for now simply exposes iree-opt) + + +def build_opt_command_line( + input_file: str, tfs: TempFileSaver, options: CompilerOptions +) -> List[str]: + """Builds a command line for applying specified patterns. + + Args: + input_file: The input file name. + tfs: TempFileSaver. + options: Compiler options. + Returns: + List of strings of command line. + """ + iree_opt = find_tool("iree-opt") + cl = [ + iree_opt, + input_file, + ] + + # Output file. + if options.output_file: + cl.append(f"-o={options.output_file}") + + # Tool paths. + lld_path = find_tool("iree-lld") + cl.append(f"--iree-llvmcpu-embedded-linker-path={lld_path}") + + crash_reproducer_path = tfs.alloc_optional( + "core-reproducer.mlir", export_as=options.crash_reproducer_path + ) + if crash_reproducer_path: + cl.append(f"--mlir-pass-pipeline-crash-reproducer={crash_reproducer_path}") + + cl.extend(options.extra_args) + print(cl) + return cl + + +def build_opt_command_line( + input_file: str, tfs: TempFileSaver, options: CompilerOptions +) -> List[str]: + """Builds a command line for applying specified patterns. + + Args: + input_file: The input file name. + tfs: TempFileSaver. 
+ options: Compiler options. + Returns: + List of strings of command line. + """ + iree_opt = find_tool("iree-opt") + cl = [ + iree_opt, + input_file, + ] + + # Output file. + if options.output_file: + cl.append(f"-o={options.output_file}") + + # Tool paths. + lld_path = find_tool("iree-lld") + cl.append(f"--iree-llvm-embedded-linker-path={lld_path}") + + crash_reproducer_path = tfs.alloc_optional( + "core-reproducer.mlir", export_as=options.crash_reproducer_path + ) + if crash_reproducer_path: + cl.append(f"--mlir-pass-pipeline-crash-reproducer={crash_reproducer_path}") + + cl.extend(options.extra_args) + print(cl) + return cl + + +def preprocess_file(input_file: str, **kwargs): + """Invokes iree-opt on an input file. + + Args: + input_file: File containing MLIR assembly to compile. + **kwargs: Keyword arguments corresponding to CompilerOptions. + Returns: + Either a byte buffer of the compiled content or None if output_file + was specified in the options. + """ + with TempFileSaver.implicit() as tfs: + options = CompilerOptions(**kwargs) + retained_output_file = tfs.alloc_optional( + "core-output.bin", export_as=options.output_file + ) + if options.output_file: + options.output_file = retained_output_file + cl = build_opt_command_line(input_file, tfs, options) + + # Save a temp file with the command line. + retained_cl = tfs.alloc_optional("core-command-line.txt") + if retained_cl: + with open(retained_cl, "wt") as f: + f.write(" ".join(cl)) + + result = invoke_immediate(cl) + if options.output_file: + return None + # Output as string needs to write to the retained output file itself. + if retained_output_file: + with open(retained_output_file, "wb") as f: + f.write(result) + return result + + +def preprocess_str(input_str: Union[str, bytes], **kwargs): + """Invokes the IREE compiler with an input string. + + Args: + input_str: MLIR assembly to parse/compile (str or bytes). + **kwargs: Keyword arguments corresponding to CompilerOptions. + Returns: + Either a byte buffer of the compiled content or None if output_file + was specified in the options. + """ + with TempFileSaver.implicit() as tfs: + retained_input_file = tfs.alloc_optional("core-input.mlir") + if retained_input_file: + with open( + retained_input_file, "wt" if isinstance(input_str, str) else "wb" + ) as f: + f.write(input_str) + options = CompilerOptions(**kwargs) + retained_output_file = tfs.alloc_optional( + "core-output.bin", export_as=options.output_file + ) + if options.output_file: + options.output_file = retained_output_file + cl = build_opt_command_line("-", tfs, options) + input_bytes = ( + input_str.encode("utf-8") if isinstance(input_str, str) else input_str + ) + + # Save a temp file with the command line. + retained_cl = tfs.alloc_optional("core-command-line.txt") + if retained_cl: + with open(retained_cl, "wt") as f: + f.write(" ".join(cl)) + + result = invoke_immediate(cl, immediate_input=input_bytes) + if options.output_file: + return None + + # Output as string needs to write to the retained output file itself. + if retained_output_file: + with open(retained_output_file, "wb") as f: + f.write(result) + return result diff --git a/compiler/setup.py b/compiler/setup.py index 01a347d296b9..df229d993143 100644 --- a/compiler/setup.py +++ b/compiler/setup.py @@ -251,14 +251,16 @@ def prepare_installation(): "-GNinja", "--log-level=VERBOSE", "-DIREE_BUILD_PYTHON_BINDINGS=ON", - "-DIREE_BUILD_SAMPLES=OFF", - "-DIREE_BUILD_TESTS=OFF", # Disable .so.0 style symlinking. 
Python wheels don't preserve links, # so this ~doubles the binary size if not disabled (yikes!). "-DCMAKE_PLATFORM_NO_VERSIONED_SONAME=ON", + "-DIREE_BUILD_TESTS=OFF", + "-DIREE_BUILD_SAMPLES=OFF", "-DPython3_EXECUTABLE={}".format(sys.executable), "-DCMAKE_BUILD_TYPE={}".format(cfg), # TODO(scotttodd): include IREE_TARGET_BACKEND_WEBGPU here (and in env) + get_env_cmake_option("IREE_TARGET_BACKEND_ROCM"), + get_env_cmake_option("IREE_TARGET_BACKEND_OPENCL_SPIRV"), get_env_cmake_option("IREE_ENABLE_CPUINFO", "ON"), get_env_cmake_option("IREE_TARGET_BACKEND_ROCM", "ON"), get_env_cmake_option("IREE_ENABLE_LLD", "OFF"), diff --git a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp index 88d3474402d4..c1685ce0d43b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp @@ -736,6 +736,25 @@ struct RemoveDynamicCastOp final : public OpRewritePattern { } }; +/// Removes memref.cast that turns dynamic shapes into static shapes. +struct RemoveStaticCastOp final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::CastOp castOp, + PatternRewriter &rewriter) const override { + auto srcType = castOp.getSource().getType().cast(); + auto dstType = castOp.getType().cast(); + // Restrict to the cases we generate in this pass--1-D static shape to 1-D + // dynamic shape. + if (srcType.getRank() == 1 && !srcType.hasStaticShape() && + dstType.getRank() == 1 && dstType.hasStaticShape()) { + rewriter.replaceOp(castOp, castOp.getSource()); + return success(); + } + return failure(); + } +}; + //===----------------------------------------------------------------------===// // Pass //===----------------------------------------------------------------------===// @@ -894,6 +913,7 @@ struct FlattenMemRefSubspanPass memref::AllocaOp::getCanonicalizationPatterns(cleanupPatterns, context); memref::SubViewOp::getCanonicalizationPatterns(cleanupPatterns, context); cleanupPatterns.add(context); + cleanupPatterns.add(context); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(cleanupPatterns)))) { diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp index 2c71aa9444fa..fcfec190ab41 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp @@ -195,7 +195,7 @@ class VectorReductionToGPUPass bool expandSubgroupReduction, std::function getWarpSize) : expandSubgroupReduction(expandSubgroupReduction), - getWarpSize(getWarpSize) {} + getWarpSize(std::move(getWarpSize)) {} void getDependentDialects(DialectRegistry ®istry) const override { registry.insert>, vector<1xf32> // CHECK: vector.transfer_write {{.*}} : vector<1xf32>, memref<128x32xf32> // CHECK: return + + +// ----- + +// Check that we multi-row matvec gets distributed across subgoroup threads. 
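+// (With the 64-thread workgroup declared below, each of the four 512-element
+// reduction rows should be distributed so that every thread holds a
+// vector<4x8xf16> slice (512 / 64 = 8 elements per thread per row), followed
+// by a gpu.shuffle-based subgroup reduction and a guarded vector<4xf16>
+// store; this is what the CHECK lines after the test case verify.)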
+ +#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {target_arch = "gfx940"}> +#pipeline_layout = #hal.pipeline.layout, + #hal.descriptor_set.binding<1, storage_buffer>, + #hal.descriptor_set.binding<2, storage_buffer> + ]> +]> +hal.executable private @multirow { + hal.executable.variant @rocm target(#executable_target_rocm_hsaco_fb) { + hal.executable.export @multirow layout(#pipeline_layout) attributes { + workgroup_size = [64 : index, 1 : index, 1 : index] + } + builtin.module { + func.func @multirow() { + %cst = arith.constant dense<0.000000e+00> : vector<4x512xf16> + %c0 = arith.constant 0 : index + %cst_0 = arith.constant dense<0.000000e+00> : vector<1x4xf16> + %c4096 = arith.constant 4096 : index + %c512 = arith.constant 512 : index + %cst_1 = arith.constant 0.000000e+00 : f16 + %id = gpu.thread_id x + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<1x4096xf16, #hal.descriptor_type> + memref.assume_alignment %0, 64 : memref<1x4096xf16, #hal.descriptor_type> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<32000x4096xf16, #hal.descriptor_type> + memref.assume_alignment %1, 64 : memref<32000x4096xf16, #hal.descriptor_type> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<1x32000xf16, #hal.descriptor_type> + memref.assume_alignment %2, 64 : memref<1x32000xf16, #hal.descriptor_type> + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %3 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x] + %4 = scf.for %arg0 = %c0 to %c4096 step %c512 iter_args(%arg1 = %cst) -> (vector<4x512xf16>) { + %8 = vector.transfer_read %0[%c0, %arg0], %cst_1 {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (0, d1)>} : memref<1x4096xf16, #hal.descriptor_type>, vector<4x512xf16> + %9 = vector.transfer_read %1[%3, %arg0], %cst_1 {in_bounds = [true, true]} : memref<32000x4096xf16, #hal.descriptor_type>, vector<4x512xf16> + %10 = arith.mulf %8, %9 : vector<4x512xf16> + %11 = arith.addf %arg1, %10 : vector<4x512xf16> + scf.yield %11 : vector<4x512xf16> + } + %5 = vector.broadcast %4 : vector<4x512xf16> to vector<1x4x512xf16> + %6 = vector.multi_reduction , %5, %cst_0 [2] : vector<1x4x512xf16> to vector<1x4xf16> + %7 = vector.extract %6[0] : vector<4xf16> from vector<1x4xf16> + vector.transfer_write %7, %2[%c0, %3] {in_bounds = [true]} : vector<4xf16>, memref<1x32000xf16, #hal.descriptor_type> + return + } + } + } +} + +// CHECK-LABEL: func.func @multirow() { +// CHECK: scf.for {{.*}} -> (vector<4x8xf16>) { +// CHECK: vector.transfer_read {{.*}} : memref<32000x4096xf16, #hal.descriptor_type>, vector<4x8xf16> +// CHECK: vector.transfer_read {{.*}} : memref<1x4096xf16, #hal.descriptor_type>, vector<4x8xf16> +// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<4x8xf16> +// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4x8xf16> +// CHECK: } +// CHECK: gpu.shuffle xor +// CHECK: scf.if {{.*}} { +// CHECK: vector.transfer_write {{.*}} : vector<4xf16>, memref<1x32000xf16, #hal.descriptor_type> +// CHECK: } +// CHECK-NEXT: return diff --git a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp index d9111c7f9aa7..097bb2b0dfbc 100644 --- a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp +++ 
b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp @@ -51,6 +51,84 @@ class TransposeUnitDimToShapeCast } }; +// TODO: Move this upstream +// Hoists a vector.bitcast op to the output of the enclosing scf.if +// +// This transforms IR like: +// %0 = scf.if %1 -> (vector<16xi8>) { +// %2 = memref.load %4[%c0] : memref> +// %3 = vector.bitcast %2 : vector<4xi32> to vector<16xi8> +// scf.yield %3 : vector<16xi8> +// } else { +// scf.yield %cst : vector<16xi8> +// } +// Into: +// %0 = scf.if %1 -> (vector<4xi32>) { +// %2 = memref.load %4[%c0] : memref> +// scf.yield %2 : vector<4xi32> +// } else { +// %3 = vector.bitcast %cst : vector<16xi8> to vector<4xi32> +// scf.yield %0 : vector<4xi32> +// } +// %3 = vector.bitcast %0 : vector<4xi32> to vector<16xi8> +struct BubbleUpBitCastOfScfIf : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::IfOp ifOp, + PatternRewriter &rewriter) const override { + // Bail on more than one result for now. + scf::YieldOp thenYield = ifOp.thenYield(); + if (!thenYield || thenYield.getNumOperands() != 1) + return failure(); + auto bitcastOp = thenYield.getOperand(0).getDefiningOp(); + // Bail out if no bitcast on the if then statement. + if (!bitcastOp) + return failure(); + + VectorType castSrcType = bitcastOp.getSourceVectorType(); + VectorType castDstType = bitcastOp.getResultVectorType(); + assert(castSrcType.getRank() == castDstType.getRank()); + // Skip 0-D vector. + if (castSrcType.getRank() == 0) + return failure(); + + int64_t castSrcLastDim = castSrcType.getShape().back(); + int64_t castDstLastDim = castDstType.getShape().back(); + // Require casting to more elements; + if (castSrcLastDim > castDstLastDim) + return failure(); + + Location loc = ifOp.getLoc(); + + auto bitcastedIfOp = + rewriter.create(loc, castSrcType, ifOp.getCondition()); + bitcastedIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + bitcastedIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + + scf::YieldOp newThenYield = bitcastedIfOp.thenYield(); + auto newBitcastOp = + newThenYield.getOperand(0).getDefiningOp(); + + newThenYield.setOperand(0, newBitcastOp.getSource()); + + auto newBitcast = rewriter.create( + loc, castDstType, bitcastedIfOp.getResult(0)); + + scf::YieldOp elseYield = bitcastedIfOp.elseYield(); + if (elseYield) { + OpBuilder::InsertionGuard elseGuard(rewriter); + rewriter.setInsertionPoint(elseYield); + + Value yieldSrc = elseYield.getOperand(0); + auto elseBitcast = + rewriter.create(loc, castSrcType, yieldSrc); + elseYield.setOperand(0, elseBitcast); + } + rewriter.replaceOp(ifOp, newBitcast); + return success(); + } +}; + static void loopInvariantCodeMotion(func::FuncOp funcOp) { // Walk through all loops in a function in innermost-loop-first order. 
This // way, we first LICM from the inner loop, and place the ops in @@ -99,6 +177,7 @@ struct OptimizeVectorTransferPass { RewritePatternSet patterns(&getContext()); vector::populateBubbleVectorBitCastOpPatterns(patterns); + patterns.add(&getContext()); if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel index 8161510f8627..eba4bef21317 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel @@ -53,8 +53,11 @@ iree_compiler_cc_library( "KernelDispatch.cpp", "LLVMCPUAssignConstantOrdinals.cpp", "LLVMCPUAssignImportOrdinals.cpp", + "LLVMCPUBreakDownSubbyteExtend.cpp", "LLVMCPUCheckIRBeforeLLVMConversion.cpp", "LLVMCPUEmitVectorizationRemarks.cpp", + "LLVMCPUFoldMemRefAliasOps.cpp", + "LLVMCPUFoldVectorContractUnitDims.cpp", "LLVMCPULinkExecutables.cpp", "LLVMCPULowerExecutableTarget.cpp", "LLVMCPUMmt4dVectorLowering.cpp", @@ -68,6 +71,7 @@ iree_compiler_cc_library( "LLVMCPUUnfuseFMAOps.cpp", "LLVMCPUVectorLowering.cpp", "Passes.cpp", + "SetSpecialTilingConfigs.cpp", "TargetMLTransformInfo.cpp", "Utils.cpp", "VectorContractCustomKernels.cpp", diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt index 1250c4b17b06..e280dc0e3208 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt @@ -54,8 +54,11 @@ iree_cc_library( "KernelDispatch.cpp" "LLVMCPUAssignConstantOrdinals.cpp" "LLVMCPUAssignImportOrdinals.cpp" + "LLVMCPUBreakDownSubbyteExtend.cpp" "LLVMCPUCheckIRBeforeLLVMConversion.cpp" "LLVMCPUEmitVectorizationRemarks.cpp" + "LLVMCPUFoldMemRefAliasOps.cpp" + "LLVMCPUFoldVectorContractUnitDims.cpp" "LLVMCPULinkExecutables.cpp" "LLVMCPULowerExecutableTarget.cpp" "LLVMCPUMmt4dVectorLowering.cpp" @@ -69,6 +72,7 @@ iree_cc_library( "LLVMCPUUnfuseFMAOps.cpp" "LLVMCPUVectorLowering.cpp" "Passes.cpp" + "SetSpecialTilingConfigs.cpp" "TargetMLTransformInfo.cpp" "Utils.cpp" "VectorContractCustomKernels.cpp" diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUBreakDownSubbyteExtend.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUBreakDownSubbyteExtend.cpp new file mode 100644 index 000000000000..6d9b90384145 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUBreakDownSubbyteExtend.cpp @@ -0,0 +1,387 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" +#include "iree/compiler/Codegen/LLVMCPU/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-llvmcpu-breakdown-subbyte-extend" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace iree_compiler { +namespace { + +template +static Value shuffleMaskShift(PatternRewriter &rewriter, Location loc, + SmallVector shuffleInputs, + int64_t srcBitWidth, int64_t vectorSize) { + auto shuffleInType = llvm::cast(shuffleInputs[0].getType()); + auto shuffleResultType = + VectorType::get({vectorSize}, shuffleInType.getElementType()); + int64_t dstBitWidth = shuffleInType.getElementTypeBitWidth(); + T maskBase = (1u << srcBitWidth) - 1; + + SmallVector maskArray(shuffleResultType.getNumElements()); + for (T elemNum = 0; elemNum < shuffleResultType.getNumElements(); elemNum++) { + maskArray[elemNum] = maskBase << (elemNum * srcBitWidth % dstBitWidth); + } + auto maskVals = rewriter.create( + loc, shuffleResultType, + DenseIntElementsAttr::get(shuffleResultType, maskArray)); + LDBG("maskVals: " << maskVals); + SmallVector shruiArray(shuffleResultType.getNumElements()); + for (T elemNum = 0; elemNum < shuffleResultType.getNumElements(); elemNum++) { + shruiArray[elemNum] = elemNum * srcBitWidth % dstBitWidth; + } + auto shruiVals = rewriter.create( + loc, shuffleResultType, + DenseIntElementsAttr::get(shuffleResultType, shruiArray)); + LDBG("shruiVals: " << shruiVals); + + int64_t dstSize = vectorSize * shuffleInputs.size(); + auto newVectorType = + VectorType::get({dstSize}, shuffleResultType.getElementType()); + Value newVector = rewriter.create( + loc, newVectorType, rewriter.getZeroAttr(newVectorType)); + + for (auto shuffleIn : llvm::enumerate(shuffleInputs)) { + SmallVector shuffleArray(vectorSize); + for (int64_t elemNum = 0; elemNum < vectorSize; elemNum++) { + shuffleArray[elemNum] = + elemNum / (vectorSize / shuffleInType.getNumElements()); + } + Value shuffleResult = rewriter.create( + loc, shuffleIn.value(), shuffleIn.value(), shuffleArray); + LDBG("shuffleResult: " << shuffleResult); + + Value andResult = + rewriter.create(loc, shuffleResult, maskVals); + LDBG("andResult: " << andResult); + + Value shruiResult = + rewriter.create(loc, andResult, shruiVals); + LDBG("shruiResult: " << shruiResult); + + int64_t offset = shuffleIn.index() * vectorSize; + newVector = rewriter.create( + loc, shruiResult, newVector, offset, 1); + } + return newVector; +} + +static std::optional> +getLoadsForExtend(arith::ExtUIOp extOp) { + Value extSource = extOp.getIn(); + auto shapeCastOp = extSource.getDefiningOp(); + if (!shapeCastOp) { + return std::nullopt; + } + Value shapeCastSource = shapeCastOp.getSource(); + auto insertOp = shapeCastSource.getDefiningOp(); + if (!insertOp) { + return std::nullopt; + } + SmallVector loads; + while (insertOp) { + Value insert = insertOp.getSource(); + auto insertShapeCastOp = insert.getDefiningOp(); + if (!insertShapeCastOp) { + return std::nullopt; + } + auto loadOp = insertShapeCastOp.getSource().getDefiningOp(); + if (!loadOp) { + return std::nullopt; + } + loads.push_back(loadOp.getResult()); + insertOp = 
insertOp.getDest().getDefiningOp(); + } + return loads; +} + +struct BreakDownSubbyteExtend final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::ExtUIOp extOp, + PatternRewriter &rewriter) const override { + VectorType extuiSrcType = + llvm::dyn_cast(extOp.getIn().getType()); + VectorType extuiDstType = llvm::dyn_cast(extOp.getType()); + if (!extuiSrcType || !extuiDstType) { + return failure(); + } + + SmallVector sources{extOp.getIn()}; + if (auto loads = getLoadsForExtend(extOp)) { + sources = *loads; + } + + int64_t srcElemBitwidth = extuiSrcType.getElementTypeBitWidth(); + int64_t dstElemBitwidth = extuiDstType.getElementTypeBitWidth(); + // We only have power-of-two bitwidth cases for now. + if (!llvm::isPowerOf2_64(dstElemBitwidth) || srcElemBitwidth != 4) + return failure(); + + if (dstElemBitwidth != 32 && dstElemBitwidth != 16) { + return failure(); + } + + int64_t vectorSizeBits = 512; + int64_t vectorSize = vectorSizeBits / dstElemBitwidth; + int64_t shuffleInputSizeBits = vectorSize * srcElemBitwidth; + int64_t shuffleInputSize = shuffleInputSizeBits / dstElemBitwidth; + auto shuffleInputType = + VectorType::get({shuffleInputSize}, extuiDstType.getElementType()); + Value shuffleInput = rewriter.create( + extOp.getLoc(), shuffleInputType, + rewriter.getZeroAttr(shuffleInputType)); + SmallVector shuffleInputs; + + for (int sourceIdx = 0; sourceIdx < sources.size(); sourceIdx++) { + Value source = sources[sourceIdx]; + VectorType sourceType = llvm::cast(source.getType()); + SmallVector sourceShape(sourceType.getShape()); + int64_t innerSize = sourceShape.back(); + if (!llvm::isPowerOf2_64(innerSize)) { + return failure(); + } + for (int64_t i = 0; i < sourceType.getNumElements() / innerSize; i++) { + SmallVector indices; + int64_t numElems = i; + SmallVector sourceOuterShape(sourceShape.begin(), + sourceShape.end() - 1); + for (int64_t size : llvm::reverse(sourceOuterShape)) { + indices.push_back(numElems % size); + numElems /= size; + } + std::reverse(indices.begin(), indices.end()); + + Value innerSlice; + if (indices.size()) { + innerSlice = rewriter.create(extOp.getLoc(), + source, indices); + } else { + innerSlice = source; + } + VectorType innerSliceType = + llvm::cast(innerSlice.getType()); + int64_t numExtractedBits = + innerSliceType.getNumElements() * srcElemBitwidth; + if (numExtractedBits / dstElemBitwidth < 1) { + LDBG("extract not big enough: " << numExtractedBits / + dstElemBitwidth); + return failure(); + } + auto bitCastType = VectorType::get({numExtractedBits / dstElemBitwidth}, + extuiDstType.getElementType()); + Value bitCastResult = rewriter.create( + extOp.getLoc(), bitCastType, innerSlice); + LDBG("innerSlice: " << innerSlice); + // LDBG("bitCastResult: " << bitCastResult); + + if (numExtractedBits >= shuffleInputSizeBits) { + for (int64_t extractOffset = 0; + extractOffset < numExtractedBits / dstElemBitwidth; + extractOffset += shuffleInputSize) { + Value extractedSlice = + rewriter.create( + extOp.getLoc(), bitCastResult, extractOffset, + shuffleInputSize, 1); + shuffleInputs.push_back(extractedSlice); + LDBG("extractedSlice: " << extractedSlice); + // vector = + // rewriter.create(extOp.getLoc(), + // extractedSlice, vector, SmallVector{offset}, + // SmallVector{1}); + } + } else { + int64_t offset = + i * numExtractedBits / dstElemBitwidth % shuffleInputSize; + shuffleInput = rewriter.create( + extOp.getLoc(), bitCastResult, shuffleInput, + SmallVector{offset}, SmallVector{1}); + if (offset + 
numExtractedBits / dstElemBitwidth == shuffleInputSize) { + shuffleInputs.push_back(shuffleInput); + shuffleInput = rewriter.create( + extOp.getLoc(), shuffleInputType, + rewriter.getZeroAttr(shuffleInputType)); + } + } + } + } + + Value newVector; + if (dstElemBitwidth == 32) { + newVector = shuffleMaskShift( + rewriter, extOp.getLoc(), shuffleInputs, srcElemBitwidth, vectorSize); + } else if (dstElemBitwidth == 16) { + newVector = shuffleMaskShift( + rewriter, extOp.getLoc(), shuffleInputs, srcElemBitwidth, vectorSize); + } + rewriter.replaceOpWithNewOp(extOp, extuiDstType, + newVector); + + return success(); + } +}; + +struct BreakDownSubbyteExtendFlatten final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::ExtUIOp extOp, + PatternRewriter &rewriter) const override { + VectorType extuiSrcType = + llvm::dyn_cast(extOp.getIn().getType()); + VectorType extuiDstType = llvm::dyn_cast(extOp.getType()); + if (!extuiSrcType || !extuiDstType) { + return failure(); + } + LDBG("extuiSrcType: " << extuiSrcType); + LDBG("extuiDstType: " << extuiDstType); + + // We only have power-of-two bitwidth cases for now. + if (!llvm::isPowerOf2_64(extuiSrcType.getNumElements())) + return failure(); + + int64_t srcElemBitwidth = extuiSrcType.getElementTypeBitWidth(); + int64_t dstElemBitwidth = extuiDstType.getElementTypeBitWidth(); + LDBG("srcElemBitwidth: " << srcElemBitwidth); + LDBG("dstElemBitwidth: " << dstElemBitwidth); + + int64_t numBits = srcElemBitwidth * extuiSrcType.getNumElements(); + if (numBits / dstElemBitwidth < 1) { + return failure(); + } + + VectorType flattenedType = VectorType::get({extuiSrcType.getNumElements()}, + extuiSrcType.getElementType()); + Value shapeCastFlatten = rewriter.create( + extOp.getLoc(), flattenedType, extOp.getIn()); + + auto bitCastType = VectorType::get({numBits / dstElemBitwidth}, + extuiDstType.getElementType()); + Value bitCastResult = rewriter.create( + extOp.getLoc(), bitCastType, shapeCastFlatten); + LDBG("bitCastResult: " << bitCastResult); + + SmallVector shuffleArray(extuiDstType.getNumElements()); + for (int64_t elemNum = 0; elemNum < extuiDstType.getNumElements(); + elemNum++) { + shuffleArray[elemNum] = elemNum / (extuiDstType.getNumElements() / + bitCastType.getNumElements()); + } + + Value shuffleResult = rewriter.create( + extOp.getLoc(), bitCastResult, bitCastResult, shuffleArray); + LDBG("shuffleResult: " << shuffleResult); + + Value shapeCastUnflatten = rewriter.create( + extOp.getLoc(), extuiDstType, shuffleResult); + Value maskVals, shruiVals; + if (dstElemBitwidth == 32) { + int32_t maskBase = (1u << srcElemBitwidth) - 1; + SmallVector maskArray(extuiDstType.getNumElements()); + for (int32_t elemNum = 0; elemNum < extuiDstType.getNumElements(); + elemNum++) { + maskArray[elemNum] = maskBase + << (elemNum * srcElemBitwidth % dstElemBitwidth); + } + maskVals = rewriter.create( + extOp.getLoc(), extuiDstType, + DenseIntElementsAttr::get(extuiDstType, maskArray)); + LDBG("maskVals: " << maskVals); + + SmallVector shruiArray(extuiDstType.getNumElements()); + for (int32_t elemNum = 0; elemNum < extuiDstType.getNumElements(); + elemNum++) { + shruiArray[elemNum] = elemNum * srcElemBitwidth % dstElemBitwidth; + } + shruiVals = rewriter.create( + extOp.getLoc(), extuiDstType, + DenseIntElementsAttr::get(extuiDstType, shruiArray)); + LDBG("shruiVals: " << shruiVals); + } else if (dstElemBitwidth == 16) { + int16_t maskBase = (1u << srcElemBitwidth) - 1; + SmallVector 
maskArray(extuiDstType.getNumElements()); + for (int16_t elemNum = 0; elemNum < extuiDstType.getNumElements(); + elemNum++) { + maskArray[elemNum] = maskBase + << (elemNum * srcElemBitwidth % dstElemBitwidth); + } + maskVals = rewriter.create( + extOp.getLoc(), extuiDstType, + DenseIntElementsAttr::get(extuiDstType, maskArray)); + LDBG("maskVals: " << maskVals); + + SmallVector shruiArray(extuiDstType.getNumElements()); + for (int16_t elemNum = 0; elemNum < extuiDstType.getNumElements(); + elemNum++) { + shruiArray[elemNum] = elemNum * srcElemBitwidth % dstElemBitwidth; + } + shruiVals = rewriter.create( + extOp.getLoc(), extuiDstType, + DenseIntElementsAttr::get(extuiDstType, shruiArray)); + LDBG("shruiVals: " << shruiVals); + } else { + return failure(); + } + + Value andResult = rewriter.create( + extOp.getLoc(), shapeCastUnflatten, maskVals); + LDBG("andResult: " << andResult); + + rewriter.replaceOpWithNewOp(extOp, andResult, shruiVals); + + return success(); + } +}; + +struct LLVMCPUBreakDownSubbyteExtendPass final + : public LLVMCPUBreakDownSubbyteExtendBase< + LLVMCPUBreakDownSubbyteExtendPass> { + void runOnOperation() override { + MLIRContext *context = &getContext(); + { + RewritePatternSet patterns(context); + patterns.add(context); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } + + // For the case when the innermost dimension of the src type is too small to + // fill a single element of the dst type. + // { + // RewritePatternSet patterns(context); + // patterns.add(context); + // vector::populateVectorShapeCastLoweringPatterns(patterns); + // if (failed(applyPatternsAndFoldGreedily(getOperation(), + // std::move(patterns)))) { + // return signalPassFailure(); + // } + // } + } +}; + +} // namespace + +std::unique_ptr> +createLLVMCPUBreakDownSubbyteExtendPass() { + return std::make_unique(); +} + +void populateLLVMCPUBreakDownSubbyteExtendPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldMemRefAliasOps.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldMemRefAliasOps.cpp new file mode 100644 index 000000000000..fc8c40dfb2e2 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldMemRefAliasOps.cpp @@ -0,0 +1,283 @@ +//===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This transformation pass folds loading/storing from/to subview ops into +// loading/storing from/to the original memref. 
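+//
+// Beyond the subview folding reused from upstream, the patterns in this file
+// fold vector.load through memref.expand_shape and memref.collapse_shape by
+// linearizing (respectively delinearizing) the access indices. A rough sketch
+// of the expand_shape case, with shapes and SSA names invented purely for
+// illustration:
+//
+//   %e = memref.expand_shape %src [[0, 1]] : memref<32xf32> into memref<4x8xf32>
+//   %v = vector.load %e[%i, %j] : memref<4x8xf32>, vector<8xf32>
+//
+// is rewritten to load directly from the original memref:
+//
+//   %k = affine.apply affine_map<(d0, d1) -> (d0 * 8 + d1)>(%i, %j)
+//   %v = vector.load %src[%k] : memref<32xf32>, vector<8xf32>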
+// +//===----------------------------------------------------------------------===// + +#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" +#include "iree/compiler/Codegen/LLVMCPU/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-llvmcpu-fold-memref-alias-ops" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") + +namespace mlir { +namespace iree_compiler { + +//===----------------------------------------------------------------------===// +// Patterns +//===----------------------------------------------------------------------===// + +namespace { + +/// Merges expand_shape operation with load/transferRead operation. +template +class LLVMCPULoadOpOfExpandShapeOpFolder final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy loadOp, + PatternRewriter &rewriter) const override; +}; + +/// Merges collapse_shape operation with load/transferRead operation. +template +class LLVMCPULoadOpOfCollapseShapeOpFolder final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy loadOp, + PatternRewriter &rewriter) const override; +}; +} // namespace + +static SmallVector +calculateExpandedAccessIndices(AffineMap affineMap, + const SmallVector &indices, Location loc, + PatternRewriter &rewriter) { + SmallVector indicesOfr(llvm::to_vector( + llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }))); + SmallVector expandedIndices; + for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) { + OpFoldResult ofr = affine::makeComposedFoldedAffineApply( + rewriter, loc, affineMap.getSubMap({i}), indicesOfr); + expandedIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + return expandedIndices; +} + +static LogicalResult +resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, + memref::ExpandShapeOp expandShapeOp, + ValueRange indices, + SmallVectorImpl &sourceIndices) { + // The below implementation uses computeSuffixProduct method, which only + // allows int64_t values (i.e., static shape). Bail out if it has dynamic + // shapes. + if (!expandShapeOp.getResultType().hasStaticShape()) + return failure(); + + MLIRContext *ctx = rewriter.getContext(); + for (ArrayRef groups : expandShapeOp.getReassociationIndices()) { + assert(!groups.empty() && "association indices groups cannot be empty"); + int64_t groupSize = groups.size(); + + // Construct the expression for the index value w.r.t to expand shape op + // source corresponding the indices wrt to expand shape op result. + SmallVector sizes(groupSize); + for (int64_t i = 0; i < groupSize; ++i) + sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]); + SmallVector suffixProduct = computeSuffixProduct(sizes); + SmallVector dims(groupSize); + bindDimsList(ctx, MutableArrayRef{dims}); + AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct); + + /// Apply permutation and create AffineApplyOp. 
+ SmallVector dynamicIndices(groupSize); + for (int64_t i = 0; i < groupSize; i++) + dynamicIndices[i] = indices[groups[i]]; + + // Creating maximally folded and composd affine.apply composes better with + // other transformations without interleaving canonicalization passes. + OpFoldResult ofr = affine::makeComposedFoldedAffineApply( + rewriter, loc, + AffineMap::get(/*numDims=*/groupSize, + /*numSymbols=*/0, srcIndexExpr), + dynamicIndices); + sourceIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + return success(); +} + +static LogicalResult +resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, + memref::CollapseShapeOp collapseShapeOp, + ValueRange indices, + SmallVectorImpl &sourceIndices) { + int64_t cnt = 0; + SmallVector tmp(indices.size()); + SmallVector dynamicIndices; + for (ArrayRef groups : collapseShapeOp.getReassociationIndices()) { + assert(!groups.empty() && "association indices groups cannot be empty"); + dynamicIndices.push_back(indices[cnt++]); + int64_t groupSize = groups.size(); + + // Calculate suffix product for all collapse op source dimension sizes. + SmallVector sizes(groupSize); + for (int64_t i = 0; i < groupSize; ++i) + sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]); + SmallVector suffixProduct = computeSuffixProduct(sizes); + + // Derive the index values along all dimensions of the source corresponding + // to the index wrt to collapsed shape op output. + auto d0 = rewriter.getAffineDimExpr(0); + SmallVector delinearizingExprs = delinearize(d0, suffixProduct); + + // Construct the AffineApplyOp for each delinearizingExpr. + for (int64_t i = 0; i < groupSize; i++) { + OpFoldResult ofr = affine::makeComposedFoldedAffineApply( + rewriter, loc, + AffineMap::get(/*numDims=*/1, /*numSymbols=*/0, + delinearizingExprs[i]), + dynamicIndices); + sourceIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + dynamicIndices.clear(); + } + if (collapseShapeOp.getReassociationIndices().empty()) { + auto zeroAffineMap = rewriter.getConstantAffineMap(0); + int64_t srcRank = + cast(collapseShapeOp.getViewSource().getType()).getRank(); + for (int64_t i = 0; i < srcRank; i++) { + OpFoldResult ofr = affine::makeComposedFoldedAffineApply( + rewriter, loc, zeroAffineMap, dynamicIndices); + sourceIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + } + return success(); +} + +/// Helpers to access the memref operand for each op. +template +static Value getMemRefOperand(LoadOrStoreOpTy op) { + return op.getMemref(); +} + +static Value getMemRefOperand(vector::TransferReadOp op) { + return op.getSource(); +} + +static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); } + +template +LogicalResult LLVMCPULoadOpOfExpandShapeOpFolder::matchAndRewrite( + OpTy loadOp, PatternRewriter &rewriter) const { + auto expandShapeOp = + getMemRefOperand(loadOp).template getDefiningOp(); + + if (!expandShapeOp) + return failure(); + + SmallVector indices(loadOp.getIndices().begin(), + loadOp.getIndices().end()); + // For affine ops, we need to apply the map to get the operands to get the + // "actual" indices. 
+ if (auto affineLoadOp = + dyn_cast(loadOp.getOperation())) { + AffineMap affineMap = affineLoadOp.getAffineMap(); + auto expandedIndices = calculateExpandedAccessIndices( + affineMap, indices, loadOp.getLoc(), rewriter); + indices.assign(expandedIndices.begin(), expandedIndices.end()); + } + SmallVector sourceIndices; + if (failed(resolveSourceIndicesExpandShape( + loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices))) + return failure(); + llvm::TypeSwitch(loadOp) + .Case([&](vector::LoadOp op) { + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getType(), expandShapeOp.getViewSource(), + sourceIndices); + }) + .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); + return success(); +} + +template +LogicalResult LLVMCPULoadOpOfCollapseShapeOpFolder::matchAndRewrite( + OpTy loadOp, PatternRewriter &rewriter) const { + auto collapseShapeOp = getMemRefOperand(loadOp) + .template getDefiningOp(); + + if (!collapseShapeOp) + return failure(); + + SmallVector indices(loadOp.getIndices().begin(), + loadOp.getIndices().end()); + // For affine ops, we need to apply the map to get the operands to get the + // "actual" indices. + if (auto affineLoadOp = + dyn_cast(loadOp.getOperation())) { + AffineMap affineMap = affineLoadOp.getAffineMap(); + auto expandedIndices = calculateExpandedAccessIndices( + affineMap, indices, loadOp.getLoc(), rewriter); + indices.assign(expandedIndices.begin(), expandedIndices.end()); + } + SmallVector sourceIndices; + if (failed(resolveSourceIndicesCollapseShape( + loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices))) + return failure(); + llvm::TypeSwitch(loadOp) + .Case([&](vector::LoadOp op) { + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getType(), collapseShapeOp.getViewSource(), + sourceIndices); + }) + .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); + return success(); +} + +void populateLLVMCPUFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) { + patterns.add, + LLVMCPULoadOpOfCollapseShapeOpFolder>( + patterns.getContext()); +} + +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +namespace { + +struct LLVMCPUFoldMemRefAliasOpsPass final + : public LLVMCPUFoldMemRefAliasOpsBase { + void runOnOperation() override; +}; + +} // namespace + +void LLVMCPUFoldMemRefAliasOpsPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + memref::populateFoldMemRefAliasOpPatterns(patterns); + populateLLVMCPUFoldMemRefAliasOpPatterns(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + +std::unique_ptr createLLVMCPUFoldMemRefAliasOpsPass() { + return std::make_unique(); +} + +} // namespace iree_compiler +} // namespace mlir \ No newline at end of file diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldVectorContractUnitDims.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldVectorContractUnitDims.cpp new file mode 100644 index 000000000000..1096693c693d --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldVectorContractUnitDims.cpp @@ -0,0 +1,354 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===- LLVMCPUFoldVectorContractUnitDims.cpp - Pass to fold unit dims of +// vector.contract ops -===// +// +// Patterns to fold away unit dimensions on `vector.contract` ops +// +//===----------------------------------------------------------------------===// + +#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" +#include "iree/compiler/Codegen/LLVMCPU/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-llvmcpu-fold-unit-reduction-dims" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace iree_compiler { + +// Given a `vector.contract` op and a set of indices to fold, this op rewrites +// the `vector.contract` op with surrounding `vector.shape_cast` ops to fold +// away the indicated indices. +static FailureOr +dropFoldableUnitIndices(PatternRewriter &rewriter, + vector::ContractionOp contractOp, + SmallVector foldIndices) { + SmallVector contractShape = *contractOp.getShapeForUnroll(); + SmallVector iteratorTypes = + contractOp.getIteratorTypesArray(); + auto indexingMaps = contractOp.getIndexingMapsArray(); + SmallVector> dstShapes; + SmallVector> dstExprs; + SmallVector inputs( + {contractOp.getLhs(), contractOp.getRhs(), contractOp.getAcc()}); + llvm::SetVector foldableDims; + for (int64_t dim : foldIndices) + foldableDims.insert(dim); + + for (AffineMap map : indexingMaps) { + SmallVector dstShape; + SmallVector dstExpr; + for (const auto &expr : enumerate(map.getResults())) { + if (auto dimExpr = llvm::dyn_cast(expr.value())) { + if (!foldableDims.contains(dimExpr.getPosition())) { + dstShape.push_back(contractShape[dimExpr.getPosition()]); + unsigned numSkipped = 0; + for (int64_t ind : foldIndices) { + if (dimExpr.getPosition() > ind) { + numSkipped++; + } + } + dstExpr.push_back( + rewriter.getAffineDimExpr(dimExpr.getPosition() - numSkipped)); + } + } else { + return failure(); + } + } + dstShapes.push_back(dstShape); + dstExprs.push_back(dstExpr); + } + + SmallVector newInputs; + SmallVector newIndexingMaps; + SmallVector newIteratorTypes; + for (auto iter : enumerate(iteratorTypes)) { + if (!foldableDims.contains(iter.index())) { + newIteratorTypes.push_back(iter.value()); + } + } + + for (int i = 0; i < 3; i++) { + // Shape unchanged + if (dstShapes[i].size() == indexingMaps[i].getResults().size()) { + newInputs.push_back(inputs[i]); + AffineMap newIndexingMap = + AffineMap::get(/*dimCount=*/contractShape.size() - foldIndices.size(), + /*symCount=*/0, dstExprs[i], contractOp.getContext()); + newIndexingMaps.push_back(newIndexingMap); + continue; + } + if (dstShapes[i].size() == 0) { + return failure(); + } + VectorType inputVecType = llvm::cast(inputs[i].getType()); + VectorType dstType = + VectorType::get(dstShapes[i], inputVecType.getElementType()); + + Value result; + auto extsiop = inputs[i].getDefiningOp(); + auto extuiop = inputs[i].getDefiningOp(); + if (!extsiop && !extuiop) { + result = rewriter.create(contractOp.getLoc(), + dstType, inputs[i]); + } 
else { + Value extIn = extsiop ? extsiop.getIn() : extuiop.getIn(); + VectorType extInType = llvm::dyn_cast(extIn.getType()); + VectorType shapeCastOutType = + VectorType::get(dstType.getShape(), extInType.getElementType()); + Value shapeCastResult = rewriter.create( + contractOp.getLoc(), shapeCastOutType, extIn); + result = extsiop ? rewriter + .create(contractOp.getLoc(), + dstType, shapeCastResult) + .getResult() + : rewriter + .create(contractOp.getLoc(), + dstType, shapeCastResult) + .getResult(); + } + AffineMap newIndexingMap = + AffineMap::get(/*dimCount=*/contractShape.size() - foldIndices.size(), + /*symCount=*/0, dstExprs[i], contractOp.getContext()); + newInputs.push_back(result); + newIndexingMaps.push_back(newIndexingMap); + } + auto newContract = + rewriter + .create( + contractOp.getLoc(), newInputs[0], newInputs[1], newInputs[2], + rewriter.getAffineMapArrayAttr(newIndexingMaps), + rewriter.getArrayAttr(llvm::to_vector(llvm::map_range( + newIteratorTypes, + [&](vector::IteratorType t) -> mlir::Attribute { + return vector::IteratorTypeAttr::get(rewriter.getContext(), + t); + })))) + .getResult(); + return newContract; +} + +// This pattern matches on a `vector.contract` op with unit size dimensions, and +// folds these dimensions away +class DropVectorContractUnitDims final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, + PatternRewriter &rewriter) const override { + LDBG("vector.contract op:\n" << contractOp); + VectorType outputType = + llvm::dyn_cast(contractOp.getAcc().getType()); + if (!outputType) { + return failure(); + } + + auto parentOp = contractOp->getParentOfType(); + if (parentOp) { + return failure(); + } + + auto iteratorTypes = contractOp.getIteratorTypesArray(); + SmallVector contractDims = *contractOp.getShapeForUnroll(); + unsigned numParallel = 0; + unsigned numReduction = 0; + SmallVector unitParallelDims; + SmallVector unitReductionDims; + SmallVector foldableDims; + for (auto size : enumerate(contractDims)) { + if (iteratorTypes[size.index()] == vector::IteratorType::parallel) { + numParallel++; + if (size.value() == 1) { + unitParallelDims.push_back(size.index()); + } + } else { + numReduction++; + if (size.value() == 1) { + unitReductionDims.push_back(size.index()); + } + } + } + if (numReduction && numReduction == unitReductionDims.size()) { + foldableDims.append(unitReductionDims.begin(), + unitReductionDims.end() - 1); + } else { + foldableDims.append(unitReductionDims.begin(), unitReductionDims.end()); + } + if (numParallel && numParallel == unitParallelDims.size()) { + foldableDims.append(unitParallelDims.begin() + 1, unitParallelDims.end()); + } else { + foldableDims.append(unitParallelDims.begin(), unitParallelDims.end()); + } + if (!foldableDims.size()) { + return failure(); + } + + FailureOr maybeNewContract = + dropFoldableUnitIndices(rewriter, contractOp, foldableDims); + if (failed(maybeNewContract)) { + return failure(); + } + Value newContract = maybeNewContract.value(); + LDBG("Replaced vector.contract:\n" << newContract); + + VectorType newOutputType = + llvm::dyn_cast(newContract.getType()); + if (outputType != newOutputType) { + // Reshape output of new vector.contract if needed + Value shapeCastResult = rewriter.create( + contractOp.getLoc(), outputType, newContract); + rewriter.replaceOp(contractOp, shapeCastResult); + } else { + rewriter.replaceOp(contractOp, newContract); + } + + return success(); + } +}; + +// This pattern 
matches on a sequence of +// `vector.shape_cast->vector.contract->vector.shape_cast` within an `scf.for` +// op, where the shape cast ops are casting an argument of the `scf.for` op and +// the yielded result of the `scf.for` op. Once matched, the `vector.shape_cast` +// ops are hoisted out of the `scf.for` op. +class HoistShapeCastOutOfSCFFor final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const override { + LDBG("forOp:\n" << forOp); + auto yieldOp = cast(forOp.getBody()->getTerminator()); + std::optional> + hoistableShapeCast = std::nullopt; + int initArgIdx; + for (Value result : yieldOp.getOperation()->getOperands()) { + auto outputShapeCastOp = result.getDefiningOp(); + if (!outputShapeCastOp) { + continue; + } + LDBG("outputShapeCastOp:\n" << outputShapeCastOp); + auto contractOp = + outputShapeCastOp.getSource().getDefiningOp(); + if (!contractOp) { + continue; + } + LDBG("contractOp:\n" << contractOp); + Value acc = contractOp.getAcc(); + auto inputShapeCastOp = acc.getDefiningOp(); + if (!inputShapeCastOp) { + continue; + } + LDBG("inputShapeCastOp:\n" << inputShapeCastOp); + Value input = inputShapeCastOp.getSource(); + auto blockArg = dyn_cast(input); + if (!blockArg) { + continue; + } + LDBG("blockArg:\n" << blockArg); + hoistableShapeCast = std::make_pair(inputShapeCastOp, outputShapeCastOp); + initArgIdx = blockArg.getArgNumber() - 1; + } + + if (!hoistableShapeCast) { + return failure(); + } + vector::ShapeCastOp inSC = hoistableShapeCast->first; + vector::ShapeCastOp outSC = hoistableShapeCast->second; + SmallVector forOpInitArgs = forOp.getInitArgs(); + Value source = forOpInitArgs[initArgIdx]; + Value sourceSC = + rewriter + .create(forOp.getLoc(), inSC.getType(), source) + .getResult(); + forOpInitArgs[initArgIdx] = sourceSC; + auto newForOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), forOpInitArgs); + LDBG("newForOp:\n" << newForOp); + rewriter.mergeBlocks(forOp.getBody(), newForOp.getBody(), + newForOp.getBody()->getArguments()); + auto newYieldOp = cast(newForOp.getBody()->getTerminator()); + LDBG("newYieldOp:\n" << newYieldOp); + SmallVector newForOpResults = + newYieldOp.getOperation()->getOperands(); + int contractResultIndex; + for (auto result : llvm::enumerate(newForOpResults)) { + if (result.value() == outSC.getResult()) { + newForOpResults[result.index()] = outSC.getSource(); + contractResultIndex = result.index(); + } + } + rewriter.updateRootInPlace(newYieldOp, [&]() { + newYieldOp.getOperation()->setOperands(newForOpResults); + }); + LDBG("newForOp with body:\n" << newForOp); + SmallVector newResults = newForOp.getResults(); + Value hoistedOutputShapeCast = + rewriter + .create(forOp.getLoc(), outSC.getType(), + newResults[contractResultIndex]) + .getResult(); + LDBG("hoistedOutputShapeCast:\n" << hoistedOutputShapeCast); + newResults[contractResultIndex] = hoistedOutputShapeCast; + rewriter.replaceOp(forOp, newResults); + + return success(); + } +}; + +namespace { +struct LLVMCPUFoldVectorContractUnitDimsPass + : public LLVMCPUFoldVectorContractUnitDimsBase< + LLVMCPUFoldVectorContractUnitDimsPass> { + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override; +}; +} // namespace + +void LLVMCPUFoldVectorContractUnitDimsPass::runOnOperation() { + Operation *funcOp = getOperation(); + MLIRContext *context = 
&getContext(); + RewritePatternSet foldUnitDimsPatterns(context); + foldUnitDimsPatterns + .add(context); + if (failed(applyPatternsAndFoldGreedily(funcOp, + std::move(foldUnitDimsPatterns)))) { + return signalPassFailure(); + } +} + +std::unique_ptr> +createLLVMCPUFoldVectorContractUnitDimsPass() { + return std::make_unique(); +} + +void populateFoldVectorContractUnitDimsPass(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); +} + +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp index b0b36909f8fc..f00907c52a5e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp @@ -38,8 +38,9 @@ namespace { /// TODO: support named ops, numInputs > 1, and modify lastDim check below /// accordingly. If fpReductionReordering is not enabled by default, it must /// be an integer or index type to proceed to allow associative reordering. -LogicalResult splitReductionPrecondition(Operation *op, - bool fpReductionReordering) { +LogicalResult +splitReductionPrecondition(Operation *op, bool fpReductionReordering, + bool enableQuantizedMatmulReassociation) { linalg::LinalgOp linalgOp = cast(op); if (!linalgOp.hasTensorSemantics()) { @@ -63,7 +64,11 @@ LogicalResult splitReductionPrecondition(Operation *op, LLVM_DEBUG(llvm::dbgs() << "is not a generic op\n"); return failure(); } - if (linalgOp.getNumDpsInputs() != 1) { + if (enableQuantizedMatmulReassociation && linalgOp.getNumDpsInputs() > 2) { + LLVM_DEBUG(llvm::dbgs() << "doesn't have at most 2 inputs\n"); + return failure(); + } + if (!enableQuantizedMatmulReassociation && linalgOp.getNumDpsInputs() != 1) { LLVM_DEBUG(llvm::dbgs() << "doesn't have exactly 1 input\n"); return failure(); } @@ -102,8 +107,10 @@ LogicalResult splitReductionPrecondition(Operation *op, /// Converts an inner-reduction into outer reduction + inner-parallel dimension, /// followed by simple inner reduction. -LogicalResult splitReductionImpl(Operation *op, int64_t size, +LogicalResult splitReductionImpl(Operation *op, SmallVector tileSizes, + bool enableQuantizedMatmulReassociation, RewriterBase &rewriter) { + int64_t size = tileSizes.back(); IRRewriter::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(op); linalg::LinalgOp linalgOp = cast(op); @@ -119,8 +126,19 @@ LogicalResult splitReductionImpl(Operation *op, int64_t size, auto numLoops = linalgOp.getNumLoops(); // 1) Tile to extract a single vector-length array. - SmallVector tileSizesSVFirst(numLoops, - rewriter.getIndexAttr(1)); + SmallVector tileSizesSVFirst; + if (enableQuantizedMatmulReassociation) { + for (auto &s : tileSizes) { + if (!s) { + tileSizesSVFirst.push_back(rewriter.getIndexAttr(1)); + } else { + tileSizesSVFirst.push_back(rewriter.getIndexAttr(s)); + } + } + } else { + tileSizesSVFirst = + SmallVector(numLoops, rewriter.getIndexAttr(1)); + } tileSizesSVFirst[numLoops - 1] = rewriter.getIndexAttr(0); auto options = scf::SCFTilingOptions().setTileSizes(tileSizesSVFirst); FailureOr tileResFirst = scf::tileUsingSCFForOp( @@ -147,7 +165,11 @@ LogicalResult splitReductionImpl(Operation *op, int64_t size, rewriter.getIndexAttr(0)); // The reduction happens only in the penultimate dimension, which we now // tile. 
- tileSizesSV[numLoops - 1] = rewriter.getIndexAttr(1); + if (enableQuantizedMatmulReassociation) { + tileSizesSV[numLoops - 1] = rewriter.getIndexAttr(2); + } else { + tileSizesSV[numLoops - 1] = rewriter.getIndexAttr(1); + } options = scf::SCFTilingOptions().setTileSizes(tileSizesSV); FailureOr tileRes = scf::tileUsingSCFForOp( rewriter, cast(splitRes->splitLinalgOp.getOperation()), @@ -164,8 +186,11 @@ LogicalResult splitReductionImpl(Operation *op, int64_t size, class LLVMCPUSplitReductionPass : public LLVMCPUSplitReductionBase { public: - LLVMCPUSplitReductionPass(bool fpReductionReordering) { + LLVMCPUSplitReductionPass(bool fpReductionReordering, + bool enableQuantizedMatmulReassociation) { this->enableFpReductionReordering = fpReductionReordering; + this->enableQuantizedMatmulReassociation = + enableQuantizedMatmulReassociation; } void getDependentDialects(DialectRegistry ®istry) const override { @@ -183,8 +208,9 @@ void LLVMCPUSplitReductionPass::runOnOperation() { funcOp.walk([&](linalg::GenericOp op) { candidates.push_back(op); }); for (auto genericOp : candidates) { LLVM_DEBUG(llvm::dbgs() << "candidate: " << genericOp << "\n"); - if (failed(splitReductionPrecondition(genericOp, - enableFpReductionReordering))) { + if (failed( + splitReductionPrecondition(genericOp, enableFpReductionReordering, + enableQuantizedMatmulReassociation))) { continue; } @@ -208,8 +234,9 @@ void LLVMCPUSplitReductionPass::runOnOperation() { "skip SplitReduction"); continue; } - int64_t size = reductionSizes.back(); - if (failed(splitReductionImpl(genericOp, size, rewriter))) { + if (failed(splitReductionImpl(genericOp, reductionSizes, + enableQuantizedMatmulReassociation, + rewriter))) { return signalPassFailure(); } } @@ -218,9 +245,10 @@ void LLVMCPUSplitReductionPass::runOnOperation() { } // namespace std::unique_ptr> -createLLVMCPUSplitReductionPass(const bool enableFpReductionReordering) { +createLLVMCPUSplitReductionPass(const bool enableFpReductionReordering, + const bool enableQuantizedMatmulReassociation) { return std::make_unique( - enableFpReductionReordering); + enableFpReductionReordering, enableQuantizedMatmulReassociation); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp index 65be5e78bd03..2668c15ec999 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp @@ -45,6 +45,8 @@ class LLVMCPUVectorLoweringPass LLVMCPUVectorLoweringPass(const LLVMCPUVectorLoweringPassOptions &options) { this->splitVectorTransfersTo = options.splitVectorTransfersTo; this->lowerVectorTransposeToAVX2 = options.lowerVectorTransposeToAVX2; + this->enableQuantizedMatmulReassociation = + options.enableQuantizedMatmulReassociation; } void getDependentDialects(DialectRegistry ®istry) const override { @@ -77,6 +79,26 @@ void LLVMCPUVectorLoweringPass::runOnOperation() { .setVectorMultiReductionLowering(vectorMultiReductionLowering) .setVectorTransferSplit(vectorTransferSplit); + { + if (enableQuantizedMatmulReassociation) { + // Special-case vector.contract codegen paths. This needs to happen + // just before the generic vector ops lowerings. 
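+        // With the flag enabled, vector.contract ops matching one of the
+        // registered custom kernels (see VectorContractCustomKernels.cpp in
+        // this directory) are rewritten to target-specific inline asm here,
+        // instead of going through the generic vector lowerings below.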
+ RewritePatternSet patterns(ctx); + auto target = IREE::HAL::ExecutableTargetAttr::lookup(funcOp); + populateVectorContractCustomKernelsPatterns(target, patterns); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); + } + + LLVM_DEBUG({ + llvm::dbgs() << "\n--- After custom kernel lowering for " + "vector.contract ops ---\n"; + funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + } + } + { RewritePatternSet patterns(ctx); vector::populateVectorGatherLoweringPatterns(patterns); @@ -173,6 +195,23 @@ void LLVMCPUVectorLoweringPass::runOnOperation() { llvm::dbgs() << "\n\n"; }); + // Break down subbyte `arith.extui` ops + { + if (enableQuantizedMatmulReassociation) { + RewritePatternSet patterns(&getContext()); + populateLLVMCPUBreakDownSubbyteExtendPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); + } + + LLVM_DEBUG({ + llvm::dbgs() << "\n--- After breaking down subbyte extend ops ---\n"; + funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + } + } + // 'vector.shape_cast' are very expensive operations that are even generated // by some of the lowerings above (e.g., transpose lowering). There are // chances to cancel them out if they are not lowered too early so we lower diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index b0fbf23e2cca..710a434cb6d2 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp @@ -79,6 +79,13 @@ static llvm::cl::opt clInstrumentMemoryAccesses{ "instrumentation is enabled."), llvm::cl::init(false)}; +static llvm::cl::opt clEnableQuantizedMatmulReassociation( + "iree-llvmcpu-enable-quantized-matmul-reassociation", + llvm::cl::desc( + "Enables LLVMCPU codegen optimizations specific to reassociated " + "quantized matmuls (experimental)."), + llvm::cl::init(false)); + static void addTileAndDistributePasses(OpPassManager &pm) { pm.addPass(createTileAndDistributeToWorkgroupsPass()); auto &nestedModulePM = pm.nest(); @@ -163,15 +170,17 @@ LogicalResult verifyDoubleTilingExpertPassPipelineConfig( } } - SmallVector thirdLevelTileSizes; - std::tie(thirdLevelTileSizes, std::ignore) = - tilingConfig.getVectorReductionSizes(); - for (auto [index, tileSize] : llvm::enumerate(thirdLevelTileSizes)) { - if (tileSize != 0 && pLoopsSet.contains(index)) { - return op->emitOpError( - "expected only reduction dims to be set in the third tiling " - "level, got ") - << index << "-th tile size set"; + if (!clEnableQuantizedMatmulReassociation) { + SmallVector thirdLevelTileSizes; + std::tie(thirdLevelTileSizes, std::ignore) = + tilingConfig.getVectorReductionSizes(); + for (auto [index, tileSize] : llvm::enumerate(thirdLevelTileSizes)) { + if (tileSize != 0 && pLoopsSet.contains(index)) { + return op->emitOpError("expected only reduction dims to be set in " + "the third tiling " + "level, got ") + << index << "-th tile size set"; + } } } } @@ -348,7 +357,9 @@ void addMultiTilingExpertPassPipeline( // Run SplitReductionPass before the final reduction Fuse pass, because // SplitReductionPass takes care of banked-tiling. 
nestedModulePM.addNestedPass( - createLLVMCPUSplitReductionPass(clEnableReassociateFpReductions)); + createLLVMCPUSplitReductionPass( + clEnableReassociateFpReductions, + clEnableQuantizedMatmulReassociation)); nestedModulePM.addNestedPass(createLLVMCPUTilePass(i)); continue; } @@ -385,11 +396,17 @@ void addMultiTilingExpertPassPipeline( // Run IREE specific passes before vector lowering expert. nestedModulePM.addNestedPass( createRemoveSingleIterationLoopPass()); + if (clEnableQuantizedMatmulReassociation) { + nestedModulePM.addNestedPass( + createLLVMCPUFoldVectorContractUnitDimsPass()); + } { LLVMCPUVectorLoweringPassOptions options; options.lowerVectorTransposeToAVX2 = lowerToAVX2; options.splitVectorTransfersTo = "linalg-copy"; + options.enableQuantizedMatmulReassociation = + clEnableQuantizedMatmulReassociation; nestedModulePM.addNestedPass( createLLVMCPUVectorLoweringPass(options)); } @@ -649,6 +666,9 @@ static void addLowerToLLVMPasses(OpPassManager &passManager, passManager.addNestedPass(arith::createArithExpandOpsPass()); passManager.addNestedPass(memref::createExpandOpsPass()); passManager.addPass(memref::createFoldMemRefAliasOpsPass()); + if (clEnableQuantizedMatmulReassociation) { + passManager.addPass(createLLVMCPUFoldMemRefAliasOpsPass()); + } passManager.addPass(createEmulateNarrowTypePass()); passManager.addPass(createCanonicalizerPass()); passManager.addPass(createCSEPass()); @@ -670,8 +690,10 @@ static void addLowerToLLVMPasses(OpPassManager &passManager, void buildLLVMCPUCodegenConfigurationPassPipeline(OpPassManager &passManager) { { - addCommonTargetExecutablePreprocessingPasses(passManager); OpPassManager &modulePassManager = passManager.nest(); + modulePassManager.addNestedPass( + createSetSpecialTilingConfigsPass()); + addCommonTargetExecutablePreprocessingPasses(passManager); modulePassManager.addNestedPass( createRematerializeParallelOpsPass()); // TODO(#13888): This(createExpandF16OpToF32Pass()) pass is being added way diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h index 75b82de33a1d..8a370c510b5b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h @@ -19,6 +19,10 @@ namespace mlir::iree_compiler { class TilingConfig; +// Pass to breakdown subbyte extui +std::unique_ptr> +createLLVMCPUBreakDownSubbyteExtendPass(); + /// Performs the final conversion to LLVM dialect. std::unique_ptr> createConvertToLLVMPass(bool reassociateFpReordering = false); @@ -55,8 +59,9 @@ createLLVMCPUMmt4dVectorLoweringPass(); std::unique_ptr> createLLVMCPUPeelPass(); /// Pass to perform SplitReduction transformations of `LinalgOp`s. -std::unique_ptr> -createLLVMCPUSplitReductionPass(bool enableReassociateFpReductions = false); +std::unique_ptr> createLLVMCPUSplitReductionPass( + bool enableReassociateFpReductions = false, + bool enableQuantizedMatmulReassociation = false); /// Synchronizes LLVM linkage with MLIR symbol visibility. 
std::unique_ptr> @@ -82,6 +87,7 @@ std::unique_ptr> createLLVMCPUUnfuseFMAOpsPass(); struct LLVMCPUVectorLoweringPassOptions { std::string splitVectorTransfersTo = ""; bool lowerVectorTransposeToAVX2 = false; + bool enableQuantizedMatmulReassociation = false; }; std::unique_ptr> createLLVMCPUVectorLoweringPass(); std::unique_ptr> createLLVMCPUVectorLoweringPass( @@ -96,6 +102,14 @@ createVectorContractCustomKernelsPass(); std::unique_ptr> createVerifyLinalgTransformLegalityPass(); +std::unique_ptr> +createLLVMCPUFoldVectorContractUnitDimsPass(); + +std::unique_ptr createLLVMCPUFoldMemRefAliasOpsPass(); + +std::unique_ptr> +createSetSpecialTilingConfigsPass(); + //------------------------------------------------------------------------------ // LLVMCPU Codegen specific patterns. //------------------------------------------------------------------------------ @@ -108,6 +122,11 @@ void populateUnfusedFMAOpsPassPatterns(MLIRContext *context, void populateVectorContractCustomKernelsPatterns( IREE::HAL::ExecutableTargetAttr target, RewritePatternSet &patterns); +void populateLLVMCPUBreakDownSubbyteExtendPatterns(RewritePatternSet &patterns); + +void populateFoldVectorContractUnitDimsPass(RewritePatternSet &patterns, + MLIRContext *context); + //----------------------------------------------------------------------------// // LLVMCPU backend Pass Pipelines. //----------------------------------------------------------------------------// diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td index 1491a3ba03be..5f9a81baec00 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td @@ -45,6 +45,11 @@ def LLVMCPUAssignImportOrdinals : let constructor = "mlir::iree_compiler::createLLVMCPUAssignImportOrdinalsPass()"; } +def LLVMCPUBreakDownSubbyteExtend : Pass<"iree-llvmcpu-breakdown-subbyte-extend", "func::FuncOp"> { + let summary = "Pass to break down subbyte extui ops."; + let constructor = "mlir::iree_compiler::createLLVMCPUBreakDownSubbyteExtendPass()"; +} + def LLVMCPUCheckIRBeforeLLVMConversion : Pass<"iree-llvmcpu-check-ir-before-llvm-conversion", "ModuleOp"> { let summary = "Checks CPU backend specific IR constraints (like no allocas)"; @@ -58,6 +63,20 @@ def LLVMCPUEmitVectorizationRemarks : "mlir::iree_compiler::createLLVMCPUEmitVectorizationRemarksPass()"; } +def LLVMCPUFoldVectorContractUnitDims : + Pass<"iree-llvmcpu-fold-vector-contract-unit-dims", "func::FuncOp"> { + let summary = "Fold unit dims on vector.contract ops"; + let constructor = + "mlir::iree_compiler::createLLVMCPUFoldVectorContractUnitDimsPass()"; +} + +def LLVMCPUFoldMemRefAliasOps : + Pass<"iree-llvmcpu-fold-memref-alias-ops", ""> { + let summary = "Fold combinations of memref ops"; + let constructor = + "mlir::iree_compiler::createLLVMCPUFoldMemRefAliasOpsPass()"; +} + def LLVMCPULinkExecutables : Pass<"iree-llvmcpu-link-executables", "mlir::ModuleOp"> { let summary = "Links LLVMCPU HAL executables within the top-level program module."; @@ -103,6 +122,9 @@ def LLVMCPUSplitReduction : Pass<"iree-llvmcpu-split-reduction", "func::FuncOp"> Option<"enableFpReductionReordering", "enable-fp-reduction-reordering", "bool", /*default=*/"false", "Flag to enable reduction reordering on floating points.">, + Option<"enableQuantizedMatmulReassociation", "enable-quantized-matmul-reassociation", + "bool", /*default=*/"false", + "Flag to enable optimizations for reassociated quantized matmuls.">, ]; } @@ 
-162,6 +184,9 @@ def LLVMCPUVectorLowering : Option<"lowerVectorTransposeToAVX2", "lower-vector-transpose-to-avx2", "bool", /*default=*/"false", "Add specific transpose to avx2 lowering patterns.">, + Option<"enableQuantizedMatmulReassociation", "enable-quantized-matmul-reassociation", "bool", + /*default=*/"false", + "Add specific patterns for optimizing reassociated quantized matmuls.">, ]; let constructor = "mlir::iree_compiler::createLLVMCPUVectorLoweringPass()"; @@ -173,6 +198,12 @@ def VectorContractCustomKernels : let constructor = "mlir::iree_compiler::createVectorContractCustomKernelsPass()"; } +def SetSpecialTilingConfigs : + Pass<"iree-llvmcpu-set-special-tiling-configs", "func::FuncOp"> { + let summary = "Set the tile sizes for special cases before KernelDispatch."; + let constructor = "mlir::iree_compiler::createSetSpecialTilingConfigsPass()"; +} + def VerifyLinalgTransformLegality : Pass<"iree-llvmcpu-verify-linalg-transform-legality", "ModuleOp"> { let summary = "Verify that only supported IR constructs are passed to the compiler."; diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/SetSpecialTilingConfigs.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/SetSpecialTilingConfigs.cpp new file mode 100644 index 000000000000..0b016bd947eb --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/SetSpecialTilingConfigs.cpp @@ -0,0 +1,341 @@ +#include "iree/compiler/Codegen/Dialect/IREECodegenAttrs.h" +#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" +#include "iree/compiler/Codegen/LLVMCPU/Passes.h" +#include "iree/compiler/Codegen/Utils/Utils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-llvmcpu-set-special-tiling-configs" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace iree_compiler { +namespace { + +static void setTileSizes(linalg::GenericOp intMatmul, + linalg::GenericOp reassociation, + func::FuncOp entryPointFn, + IREE::HAL::ExecutableTargetAttr target) { + int mDistSize = 1; + int nDistSize = 128; + int mSize = 1; + int nSize = 4; + int kSize = 8; + int groupSize = 1; + SmallVector mDims; + SmallVector nDims; + SmallVector kDims; + SmallVector groupDims; + SmallVector maps = intMatmul.getIndexingMapsArray(); + int lhs = 0; + int rhs = 1; + int out = 2; + auto hasDim = [&](int mapIdx, int dimIdx) -> bool { + return llvm::any_of(maps[mapIdx].getResults(), [&](AffineExpr res) { + auto expr = llvm::dyn_cast(res); + return expr && expr.getPosition() == dimIdx; + }); + }; + for (int dim = 0; dim < intMatmul.getNumLoops(); dim++) { + if (hasDim(lhs, dim) && hasDim(rhs, dim) && hasDim(out, dim)) { + groupDims.push_back(dim); + } else if (hasDim(lhs, dim) && hasDim(rhs, dim) && !hasDim(out, dim)) { + kDims.push_back(dim); + } else if (hasDim(lhs, dim) && !hasDim(rhs, dim) && hasDim(out, dim)) { + mDims.push_back(dim); + } else if (!hasDim(lhs, dim) && hasDim(rhs, dim) && hasDim(out, dim)) { + nDims.push_back(dim); + } + } + if (hasFeature(target, "+avx512bw") || hasFeature(target, "+avx512vnni")) { + kSize = 16; + } + + if (mDims.size() > 1 || nDims.size() > 1 || kDims.size() != 1 || + kDims[0] != intMatmul.getNumLoops() - 1) { + return; + } + + SmallVector distTileSizes_mm(intMatmul.getNumLoops(), 0); + SmallVector 
parallelTileSizes_mm(intMatmul.getNumLoops(), 0); + SmallVector reductionTileSizes_mm(intMatmul.getNumLoops(), 0); + SmallVector lastTileSizes_mm(intMatmul.getNumLoops(), 0); + + SmallVector distTileSizes_re(reassociation.getNumLoops(), 0); + SmallVector parallelTileSizes_re(reassociation.getNumLoops(), 0); + SmallVector reductionTileSizes_re(reassociation.getNumLoops(), 0); + SmallVector lastTileSizes_re(reassociation.getNumLoops(), 0); + + for (int mDim : mDims) { + distTileSizes_mm[mDim] = mDistSize; + parallelTileSizes_mm[mDim] = mSize; + reductionTileSizes_mm[mDim] = mSize; + + distTileSizes_re[mDim] = mDistSize; + parallelTileSizes_re[mDim] = mSize; + } + for (int nDim : nDims) { + distTileSizes_mm[nDim] = nDistSize; + parallelTileSizes_mm[nDim] = nSize; + reductionTileSizes_mm[nDim] = nSize; + + distTileSizes_re[nDim] = nDistSize; + parallelTileSizes_re[nDim] = nSize; + } + for (int kDim : kDims) { + reductionTileSizes_mm[kDim] = kSize; + } + for (int groupDim : groupDims) { + reductionTileSizes_mm[groupDim] = groupSize; + } + + TileSizesListType tileSizes_mm; + tileSizes_mm.push_back(distTileSizes_mm); + tileSizes_mm.push_back(parallelTileSizes_mm); + tileSizes_mm.push_back(reductionTileSizes_mm); + tileSizes_mm.push_back(lastTileSizes_mm); + + TileSizesListType tileSizes_re; + tileSizes_re.push_back(distTileSizes_re); + tileSizes_re.push_back(parallelTileSizes_re); + tileSizes_re.push_back(reductionTileSizes_re); + tileSizes_re.push_back(lastTileSizes_re); + + IREE::Codegen::DispatchLoweringPassPipeline passPipeline = + IREE::Codegen::DispatchLoweringPassPipeline::CPUDoubleTilingExpert; + + MLIRContext *context = entryPointFn.getContext(); + auto config_mm = + IREE::Codegen::LoweringConfigAttr::get(context, tileSizes_mm); + intMatmul->setAttr("lowering_config", config_mm); + + auto config_re = + IREE::Codegen::LoweringConfigAttr::get(context, tileSizes_re); + auto translationInfo_re = IREE::Codegen::TranslationInfoAttr::get( + entryPointFn.getContext(), passPipeline, 0, 1); + auto compilationInfo_re = IREE::Codegen::CompilationInfoAttr::get( + context, config_re, translationInfo_re, ArrayRef({}), + std::nullopt); + + reassociation->setAttr("compilation_info", compilationInfo_re); + + return; +} + +static bool isIntegerMatmul(linalg::GenericOp genericOp) { + if (genericOp.getNumDpsInits() != 1) { + LDBG("Wrong number of outputs for matmul: " << genericOp.getNumDpsInits() + << "\n"); + return false; + } + if (genericOp.getNumDpsInputs() != 2) { + LDBG("Wrong number of inputs for matmul: " << genericOp.getNumDpsInputs() + << "\n"); + return false; + } + + unsigned numLoops = genericOp.getNumLoops(); + unsigned numReductionLoops = genericOp.getNumReductionLoops(); + if (numLoops != 3) { + LDBG("Wrong number of loops for matmul: " << numLoops << "\n"); + return false; + } + if (numReductionLoops != 1) { + LDBG("Wrong number of reduction loops for matmul: " << numReductionLoops + << "\n"); + return false; + } + // Work back from linalg.yield and check body of genericOp. 
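+  // The expected body has roughly this shape (illustrative sketch inferred
+  // from the checks below, not copied from a test case):
+  //   %lhs_ext = arith.extsi %lhs_narrow : <narrow int> to <wide int>
+  //   %rhs_ext = arith.extui %rhs_narrow : <narrow int> to <wide int>
+  //   %mul     = arith.muli %lhs_ext, %rhs_ext
+  //   %sum     = arith.addi %mul, %acc
+  //   linalg.yield %sum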
+ auto yieldOp = cast(genericOp.getBody()->getTerminator()); + Value producerOutput; + Operation *producer; + Operation *mulRhsProducer; + + // Producer of linalg.yield op is arith.addi + { + producerOutput = yieldOp->getOperand(0); + producer = producerOutput.getDefiningOp(); + if (!producer || producer->getNumOperands() == 0) + return false; + if (!matchPattern(producer, m_Op())) + return false; + } + + // Producer of arith.addi op is arith.muli + { + producerOutput = producer->getOperand(0); + producer = producerOutput.getDefiningOp(); + if (!producer || producer->getNumOperands() == 0) + return false; + if (!matchPattern(producer, m_Op())) + return false; + } + + // Producer of arith.muli op RHS is arith.extui + { + producerOutput = producer->getOperand(1); + mulRhsProducer = producerOutput.getDefiningOp(); + if (!mulRhsProducer || mulRhsProducer->getNumOperands() == 0) + return false; + if (!matchPattern(mulRhsProducer, m_Op())) + return false; + } + + // Producer of arith.subf op LHS is arith.extsi + { + producerOutput = producer->getOperand(0); + producer = producerOutput.getDefiningOp(); + if (!producer || producer->getNumOperands() == 0) + return false; + if (!matchPattern(producer, m_Op())) + return false; + } + + return true; +} + +static bool isReassociatedDequantizationOp(linalg::GenericOp genericOp) { + if (genericOp.getNumDpsInits() != 1) { + LDBG("Wrong number of outputs: " << genericOp.getNumDpsInits() << "\n"); + return false; + } + if (genericOp.getNumDpsInputs() != 5) { + LDBG("Wrong number of inputs: " << genericOp.getNumDpsInputs() << "\n"); + return false; + } + + unsigned numLoops = genericOp.getNumLoops(); + unsigned numReductionLoops = genericOp.getNumReductionLoops(); + if (numLoops != 2) { + LDBG("Wrong number of loops: " << numLoops << "\n"); + return false; + } + if (numReductionLoops != 1) { + LDBG("Wrong number of reduction loops: " << numReductionLoops << "\n"); + return false; + } + // Work back from linalg.yield and check body of genericOp. 
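+  // The expected body has roughly this shape (illustrative sketch inferred
+  // from the checks below, not copied from a test case):
+  //   %t0  = arith.sitofp %int_acc
+  //   %t1  = arith.mulf %t0, ...
+  //   %lhs = arith.mulf %t1, ...
+  //   %t2  = arith.mulf ..., ...
+  //   %rhs = arith.mulf %t2, ...
+  //   %d   = arith.subf %lhs, %rhs
+  //   %sum = arith.addf %d, %acc
+  //   linalg.yield %sum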
+ auto yieldOp = cast(genericOp.getBody()->getTerminator()); + Value producerOutput; + Operation *producer; + Operation *subRhsProducer; + + // Producer of linalg.yield op is arith.addf + { + producerOutput = yieldOp->getOperand(0); + producer = producerOutput.getDefiningOp(); + if (!producer || producer->getNumOperands() == 0) + return false; + if (!matchPattern(producer, m_Op())) + return false; + } + + // Producer of arith.addf op is arith.subf + { + producerOutput = producer->getOperand(0); + producer = producerOutput.getDefiningOp(); + if (!producer || producer->getNumOperands() == 0) + return false; + if (!matchPattern(producer, m_Op())) + return false; + } + + // Producer of arith.subf op RHS is arith.mulf + { + producerOutput = producer->getOperand(1); + subRhsProducer = producerOutput.getDefiningOp(); + if (!subRhsProducer || subRhsProducer->getNumOperands() == 0) + return false; + if (!matchPattern(subRhsProducer, m_Op())) + return false; + } + + // Producer of arith.mulf from arith.subf RHS is arith.mulf + { + producerOutput = subRhsProducer->getOperand(0); + subRhsProducer = producerOutput.getDefiningOp(); + if (!subRhsProducer || subRhsProducer->getNumOperands() == 0) + return false; + if (!matchPattern(subRhsProducer, m_Op())) + return false; + } + + // Producer of arith.subf op LHS is arith.mulf + { + producerOutput = producer->getOperand(0); + producer = producerOutput.getDefiningOp(); + if (!producer || producer->getNumOperands() == 0) + return false; + if (!matchPattern(producer, m_Op())) + return false; + } + + // Producer of arith.mulf op is arith.mulf + { + producerOutput = producer->getOperand(0); + producer = producerOutput.getDefiningOp(); + if (!producer || producer->getNumOperands() == 0) + return false; + if (!matchPattern(producer, m_Op())) + return false; + } + + // Producer of arith.mulf op is arith.sitofp + { + producerOutput = producer->getOperand(0); + producer = producerOutput.getDefiningOp(); + if (!producer || producer->getNumOperands() == 0) + return false; + if (!matchPattern(producer, m_Op())) + return false; + } + + return true; +} + +struct SetSpecialTilingConfigsPass + : public SetSpecialTilingConfigsBase { + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + auto funcOp = getOperation(); + auto target = IREE::HAL::ExecutableTargetAttr::lookup(funcOp); + std::optional> + reassociatedQuantizedMatmulOps = std::nullopt; + for (auto genericOp : + funcOp.getFunctionBody().getOps()) { + if (isReassociatedDequantizationOp(genericOp)) { + auto intMatmulOp = + genericOp.getInputs()[0].getDefiningOp(); + if (intMatmulOp) { + if (isIntegerMatmul(intMatmulOp)) { + reassociatedQuantizedMatmulOps = + std::make_pair(intMatmulOp, genericOp); + break; + } + } + } + } + if (reassociatedQuantizedMatmulOps) { + setTileSizes(reassociatedQuantizedMatmulOps->first, + reassociatedQuantizedMatmulOps->second, funcOp, target); + } + } +}; +} // namespace + +std::unique_ptr> +createSetSpecialTilingConfigsPass() { + return std::make_unique(); +} + +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp index 9c22ca1078fd..3dca50496496 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp @@ -25,6 +25,10 @@ #include 
"mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#define DEBUG_TYPE "iree-vector-contract-custom-kernels" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + namespace mlir::iree_compiler { namespace { @@ -85,6 +89,73 @@ static bool isMatrixTimesMatrixTransposed(vector::ContractionOp contractionOp) { return true; } +static bool isVectorTimesMatrixTransposed(vector::ContractionOp contractionOp, + int64_t splitSize) { + // Check that the reduction is additive. + if (contractionOp.getKind() != vector::CombiningKind::ADD) { + return false; + } + // Check that there are 1 parallel and 1 reduction iterators. + unsigned numIters = splitSize ? 3 : 2; + auto iteratorTypes = contractionOp.getIteratorTypes().getValue(); + if (iteratorTypes.size() != numIters) { + return false; + } + SmallVector parallelIterators; + SmallVector reductionIterators; + for (int i = 0; i < numIters; i++) { + if (vector::isParallelIterator(iteratorTypes[i])) { + parallelIterators.push_back(i); + } else if (vector::isReductionIterator(iteratorTypes[i])) { + reductionIterators.push_back(i); + } else { + return false; + } + } + if (parallelIterators.size() != numIters - 1 || + reductionIterators.size() != 1) { + return false; + } + // Give the found iterators some idiomatic names. + const int NIter = parallelIterators[0]; + const int KIter = reductionIterators[0]; + const int SplitIter = splitSize ? parallelIterators[1] : 0; + // Check that there are 3 indexing maps. + auto indexingMaps = contractionOp.getIndexingMapsArray(); + if (indexingMaps.size() != 3) { + return false; + } + // Check that the indexing maps have the expected form. + SmallVector> expectedMapResults; + if (splitSize) { + SmallVector> res = { + {KIter, SplitIter}, {NIter, KIter, SplitIter}, {NIter, SplitIter}}; + expectedMapResults = res; + numIters = 3; + } else { + SmallVector> res = {{KIter}, {NIter, KIter}, {NIter}}; + expectedMapResults = res; + numIters = 2; + } + for (int m = 0; m < 3; ++m) { + auto map = indexingMaps[m]; + auto expectedResults = expectedMapResults[m]; + if (map.getNumDims() != numIters || + map.getNumResults() != expectedResults.size()) { + return false; + } + for (int r = 0; r < expectedResults.size(); ++r) { + int actualMapResult = + llvm::cast(map.getResults()[r]).getPosition(); + if (actualMapResult != expectedMapResults[m][r]) { + return false; + } + } + } + LDBG("passed isVectorTimesMatrixTransposed"); + return true; +} + // Returns true if `contractionOp` is of the form // matrix * transposed_matrix // where matrix is a vector<{mSize}x{kSize}xType>, and @@ -131,6 +202,31 @@ static bool matchMMT(vector::ContractionOp contractionOp, int64_t mSize, return false; } +static bool matchVMT(vector::ContractionOp contractionOp, int64_t mSize, + int64_t kSize, int64_t nSize, int splitSize, + bool *transpose = nullptr) { + if (mSize != 1) { + return false; + } + if (!isVectorTimesMatrixTransposed(contractionOp, splitSize)) { + return false; + } + VectorType lhsType = llvm::cast(contractionOp.getLhs().getType()); + VectorType rhsType = llvm::cast(contractionOp.getRhs().getType()); + auto lhsShape = lhsType.getShape(); + auto rhsShape = rhsType.getShape(); + if (splitSize && (lhsShape[1] != splitSize || rhsShape[2] != splitSize)) { + return false; + } + if (lhsShape[0] != kSize || rhsShape[1] != kSize) { + return false; + } + if (rhsShape[0] == nSize) { + return true; + } + return false; +} + // `promotedResult` is required to be a Vector. 
// If its VectorType does not have `promotedType` as its element type, or // the operand to the type-promotion op is not `unpromotedType` returns a null @@ -142,8 +238,9 @@ static bool matchMMT(vector::ContractionOp contractionOp, int64_t mSize, // Note that this only looks at the immediately defining operation, so we likely // want to have earlier passes that sink widening operations as far down as // possible, which is probably just good regardless. -static Value getUnpromotedInput(Type unpromotedType, Type promotedType, - Value promotedResult) { +static Value getUnpromotedInput(PatternRewriter &rewriter, Type unpromotedType, + Type promotedType, Value promotedResult, + bool promoteSmallTypes = false) { VectorType promotedResultVectorType = llvm::cast(promotedResult.getType()); if (promotedResultVectorType.getElementType() != promotedType) { @@ -155,13 +252,29 @@ static Value getUnpromotedInput(Type unpromotedType, Type promotedType, // TODO: handle promotion of floating point types. Not doing it for now as // it wouldn't be exercised. auto extSIOp = promotedResult.getDefiningOp(); - if (!extSIOp) { + auto extUIOp = promotedResult.getDefiningOp(); + if (!extSIOp && !extUIOp) { return nullptr; } - Value extInput = extSIOp.getIn(); + Value extInput = extSIOp ? extSIOp.getIn() : extUIOp.getIn(); if (llvm::cast(extInput.getType()).getElementType() != unpromotedType) { - return nullptr; + if (promoteSmallTypes) { + VectorType unpromotedVectorType = + VectorType::get(llvm::cast(extInput.getType()).getShape(), + unpromotedType); + return extSIOp + ? rewriter + .create(extInput.getLoc(), + unpromotedVectorType, extInput) + .getResult() + : rewriter + .create(extInput.getLoc(), + unpromotedVectorType, extInput) + .getResult(); + } else { + return nullptr; + } } return extInput; } @@ -169,12 +282,28 @@ static Value getUnpromotedInput(Type unpromotedType, Type promotedType, // Helper to create a 1D, contiguous slice of a 1D vector. static Value extract1DSlice(PatternRewriter &rewriter, Location loc, VectorType dstVecType, Value input, int position) { - assert(input.getType().cast().getRank() == 1); assert(dstVecType.getRank() == 1); - std::array offsets{position}; - std::array strides{1}; - return rewriter.create( - loc, input, offsets, dstVecType.getShape(), strides); + if (input.getType().cast().getRank() == 1) { + SmallVector offsets({position}); + SmallVector strides({1}); + SmallVector sizes(dstVecType.getShape()); + return rewriter.create(loc, input, offsets, + sizes, strides); + } else { + SmallVector inputShape( + llvm::cast(input.getType()).getShape()); + assert(inputShape.back() == dstVecType.getNumElements()); + std::reverse(inputShape.begin(), inputShape.end()); + int currentPos = position; + SmallVector indices; + for (auto size : inputShape) { + indices.push_back(currentPos % size); + currentPos = currentPos / size; + } + std::reverse(indices.begin(), indices.end()); + return rewriter.create( + loc, input, SmallVector(indices.begin(), indices.end() - 1)); + } } // Helper to extract an element of a 1D vector. @@ -188,8 +317,12 @@ static Value extract(PatternRewriter &rewriter, Location loc, Value input, } // Helper to flatten a N-dimensional vector to a 1D vector. 
-static Value flatten(PatternRewriter &rewriter, Location loc, Value vector) { +static Value flattenImperfectSize(PatternRewriter &rewriter, Location loc, + Value vector, VectorType regVectorType) { VectorType inputVecType = llvm::cast(vector.getType()); + if (regVectorType.getNumElements() == inputVecType.getShape().back()) { + return vector; + } VectorType dstType = VectorType::get(inputVecType.getNumElements(), inputVecType.getElementType()); return rewriter.create(loc, dstType, vector); @@ -206,20 +339,31 @@ static Value flatten(PatternRewriter &rewriter, Location loc, Value vector) { // (2) Be explicit about the size of the vectors involved in the kernel's // "calling convention". struct MMTKernel { - enum class ScalarType : int8_t { None, I8, I32, F32 }; + enum class ScalarType : int8_t { None, I4, I8, I16, I32, F32 }; // Element type of the LHS vectors. ScalarType lhsType = ScalarType::None; // Element type of the RHS vectors. ScalarType rhsType = ScalarType::None; // Element type of the Accumulator and output vectors. ScalarType accType = ScalarType::None; + // Optional user defined constrained codes for input and output registers. + // This is useful when the constraint code is not the same for all operands. + std::optional> lhsCode = std::nullopt; + std::optional> rhsCode = std::nullopt; + std::optional> accCode = std::nullopt; + // This flag indicates whether or not to promote inputs that have a smaller + // bitwidth than lhsType, rhsType, or accType, to the appropriate bitwidth + bool promoteSmallTypes = false; // Number of rows of the LHS and Accumulator tile. - int8_t m0 = 0; + int16_t m0 = 0; // Reduction dimension, i.e. number of columns of the LHS. - int8_t k0 = 0; + int16_t k0 = 0; // Number of rows of the RHS (note that the operation being targeted, MMT, // is matrix multiplication with a *transposed* RHS) - int8_t n0 = 0; + int16_t n0 = 0; + // Size of the added parallel dimension when the vector.contract op has been + // split with splitReduction + int16_t split0 = 0; // Number of LHS elements in the type of register to be used for the LHS. // This is > 1 if SIMD registers are to be used. // Note: LHS/RHS/Accumulator may use registers of different sizes. @@ -235,6 +379,8 @@ struct MMTKernel { int8_t rhsRegs = 0; // Number of registers needed to hold the Accumulator. int8_t accRegs = 0; + // Indicates whether to use Intel or AT&T syntax + bool useIntel = false; // If not null, points to the inline asm code template for this kernel. 
// Register operands for the LHS, RHS and Accumulator are to be referenced as // $(lhs:), $(rhs:), $(acc:) respectively, where i is a decimal @@ -249,9 +395,15 @@ struct MMTKernel { const char *asmClobbers = nullptr; void validate() const { - assert(m0 * k0 == lhsRegSize * lhsRegs); // number of elements of LHS - assert(n0 * k0 == rhsRegSize * rhsRegs); // number of elements of RHS - assert(m0 * n0 == accRegSize * accRegs); // number of elements of Accum + assert(m0 * k0 == lhsRegSize * lhsRegs || + m0 * k0 * split0 == + lhsRegSize * lhsRegs); // number of elements of LHS + assert(n0 * k0 == rhsRegSize * rhsRegs || + n0 * k0 * split0 == + rhsRegSize * rhsRegs); // number of elements of RHS + assert(m0 * n0 == accRegSize * accRegs || + m0 * n0 * split0 == + accRegSize * accRegs); // number of elements of Accum assert(lhsType != ScalarType::None); assert(rhsType != ScalarType::None); assert(accType != ScalarType::None); @@ -673,13 +825,75 @@ MMTKernel MMTKernel_8x1x1_f32f32f32_Aarch64_Baseline_InlineAsm() { return kernel; } +MMTKernel MMTKernel_1x2x4_split16_i16i16i32_x86_AVX512VNNI_InlineAsm() { + MMTKernel kernel; + kernel.lhsType = MMTKernel::ScalarType::I16; + kernel.rhsType = MMTKernel::ScalarType::I16; + kernel.accType = MMTKernel::ScalarType::I32; + kernel.promoteSmallTypes = true; + kernel.useIntel = true; + kernel.m0 = 1; + kernel.k0 = 2; + kernel.n0 = 4; + kernel.split0 = 16; + kernel.lhsRegSize = 32; + kernel.rhsRegSize = 32; + kernel.accRegSize = 16; + kernel.lhsRegs = 1; + kernel.rhsRegs = 4; + kernel.accRegs = 4; + kernel.asmImpl = R"ASM( + vpdpwssd $(acc:0), $(rhs:0), $(lhs:0) + vpdpwssd $(acc:1), $(rhs:1), $(lhs:0) + vpdpwssd $(acc:2), $(rhs:2), $(lhs:0) + vpdpwssd $(acc:3), $(rhs:3), $(lhs:0) + )ASM"; + kernel.asmClobbers = ""; + return kernel; +} + +MMTKernel MMTKernel_1x2x4_split16_i16i16i32_x86_AVX512_InlineAsm() { + MMTKernel kernel; + kernel.lhsType = MMTKernel::ScalarType::I16; + kernel.rhsType = MMTKernel::ScalarType::I16; + kernel.accType = MMTKernel::ScalarType::I32; + kernel.promoteSmallTypes = true; + kernel.useIntel = true; + kernel.m0 = 1; + kernel.k0 = 2; + kernel.n0 = 4; + kernel.split0 = 16; + kernel.lhsRegSize = 32; + kernel.rhsRegSize = 32; + kernel.accRegSize = 16; + kernel.lhsRegs = 1; + kernel.rhsRegs = 4; + kernel.accRegs = 4; + kernel.asmImpl = R"ASM( + vpmaddwd zmm17, $(rhs:0), $(lhs:0) + vpmaddwd zmm18, $(rhs:1), $(lhs:0) + vpmaddwd zmm19, $(rhs:2), $(lhs:0) + vpmaddwd zmm20, $(rhs:3), $(lhs:0) + vpaddw $(acc:0), $(acc:0), zmm17 + vpaddw $(acc:1), $(acc:1), zmm18 + vpaddw $(acc:2), $(acc:2), zmm19 + vpaddw $(acc:3), $(acc:3), zmm20 + )ASM"; + kernel.asmClobbers = "zmm17,zmm18,zmm19,zmm20"; + return kernel; +} + // Constructs the mlir::Type corresponding to a scalar type. 
Type mlirType(MLIRContext *context, MMTKernel::ScalarType t) { switch (t) { case MMTKernel::ScalarType::None: break; + case MMTKernel::ScalarType::I4: + return IntegerType::get(context, 4, IntegerType::Signless); case MMTKernel::ScalarType::I8: return IntegerType::get(context, 8, IntegerType::Signless); + case MMTKernel::ScalarType::I16: + return IntegerType::get(context, 16, IntegerType::Signless); case MMTKernel::ScalarType::I32: return IntegerType::get(context, 32, IntegerType::Signless); case MMTKernel::ScalarType::F32: @@ -704,7 +918,7 @@ class MMTKernelGenerator { ArrayRef acc) { validateOperands(lhs, rhs, acc); if (kernel.asmImpl) { - return generateAsm(rewriter, loc, lhs, rhs, acc); + return generateAsm(rewriter, loc, lhs, rhs, acc, kernel.useIntel); } // In the future we may have alternate generator paths, e.g. 1D intrinsics // or other asm paths with a different interface, e.g. handling also @@ -754,10 +968,17 @@ class MMTKernelGenerator { validate(acc, kernel.accRegs, getAccRegVectorType()); } // Helper for generateAsmCodeAndConstraints - std::string getConstraintCode() const { + std::string + getConstraintCode(std::optional kernelConstraintCode) const { + if (kernelConstraintCode) { + return std::string(*kernelConstraintCode); + } if (isAArch64(target)) { return "w"; } + if (isX86(target)) { + return "v"; + } assert(false && "what constraint code to use on this arch?"); return {}; } @@ -819,31 +1040,39 @@ class MMTKernelGenerator { // processedIdx is the index of a register in the processed asm. // Example: $5 => processedIdx == 5 int processedIdx = 0; - auto processOperands = [&](Constraints::Kind constraintKind, - const char *name, int count) { - const std::string &constraintCode = getConstraintCode(); - // unprocessedIdx is the index of a register in the unprocessed asm. - // Example: $(lhs:1) => unprocessedIdx == 1 - for (int unprocessedIdx = 0; unprocessedIdx < count; - ++unprocessedIdx, ++processedIdx) { - constraints.add(constraintKind, constraintCode); - // Perform the code replacement for the operand. - // Example: $(lhs:1) => $5 - replaceAllSubstrsInPlace( - code, llvm::formatv("$({0}:{1})", name, unprocessedIdx), - llvm::formatv("${0}", processedIdx)); - } - }; - processOperands(Constraints::Kind::InputOutput, "acc", kernel.accRegs); - processOperands(Constraints::Kind::Input, "lhs", kernel.lhsRegs); - processOperands(Constraints::Kind::Input, "rhs", kernel.rhsRegs); + auto processOperands = + [&](Constraints::Kind constraintKind, const char *name, int count, + std::optional> kernelCodes) { + const std::string &constraintCode = getConstraintCode(std::nullopt); + // unprocessedIdx is the index of a register in the unprocessed asm. + // Example: $(lhs:1) => unprocessedIdx == 1 + for (int unprocessedIdx = 0; unprocessedIdx < count; + ++unprocessedIdx, ++processedIdx) { + if (kernelCodes) { + constraints.add(constraintKind, (*kernelCodes)[unprocessedIdx]); + } else { + constraints.add(constraintKind, constraintCode); + } + // Perform the code replacement for the operand. 
+ // Example: $(lhs:1) => $5 + replaceAllSubstrsInPlace( + code, llvm::formatv("$({0}:{1})", name, unprocessedIdx), + llvm::formatv("${0}", processedIdx)); + } + }; + processOperands(Constraints::Kind::InputOutput, "acc", kernel.accRegs, + kernel.accCode); + processOperands(Constraints::Kind::Input, "lhs", kernel.lhsRegs, + kernel.lhsCode); + processOperands(Constraints::Kind::Input, "rhs", kernel.rhsRegs, + kernel.rhsCode); constraints.setClobbers(kernel.asmClobbers); constraintsString = constraints.toString(); } // Helper for generate(). Implements the asm path. SmallVector generateAsm(PatternRewriter &rewriter, Location loc, ArrayRef lhs, ArrayRef rhs, - ArrayRef acc) { + ArrayRef acc, bool useIntel) { SmallVector inputs; // First the input operands. Then the input-output operands, which, as far // as input constraints are concerned, are *tied* inputs, i.e. refer to @@ -863,9 +1092,13 @@ class MMTKernelGenerator { SmallVector outputOperandTypes( llvm::map_range(acc, [](Value v) { return v.getType(); })); auto returnType = - LLVM::LLVMStructType::getLiteral(context, outputOperandTypes); + outputOperandTypes.size() == 1 + ? outputOperandTypes[0] + : LLVM::LLVMStructType::getLiteral(context, outputOperandTypes); auto dialectAttr = - LLVM::AsmDialectAttr::get(context, LLVM::AsmDialect::AD_ATT); + useIntel + ? LLVM::AsmDialectAttr::get(context, LLVM::AsmDialect::AD_Intel) + : LLVM::AsmDialectAttr::get(context, LLVM::AsmDialect::AD_ATT); std::string code; std::string constraints; generateAsmCodeAndConstraints(code, constraints); @@ -875,10 +1108,14 @@ class MMTKernelGenerator { /*operand_attrs=*/ArrayAttr()); // Extract result vectors from the asm op. SmallVector resVec; - for (int i = 0; i < kernel.accRegs; ++i) { - SmallVector position = {i}; - resVec.push_back( - rewriter.create(loc, asmOp.getRes(), position)); + if (outputOperandTypes.size() == 1) { + resVec.push_back(asmOp.getRes()); + } else { + for (int i = 0; i < kernel.accRegs; ++i) { + SmallVector position = {i}; + resVec.push_back(rewriter.create( + loc, asmOp.getRes(), position)); + } } return resVec; } @@ -913,7 +1150,9 @@ class MMTCustomKernelPattern : public OpRewritePattern { // Check if `contractionOp` matches, and obtain the (un-promoted) input // LHS and RHS vectors. bool transposeKernel = false; - if (!matchMMT(contractionOp, kernel.m0, kernel.k0, kernel.n0, + if (!matchVMT(contractionOp, kernel.m0, kernel.k0, kernel.n0, kernel.split0, + &transposeKernel) && + !matchMMT(contractionOp, kernel.m0, kernel.k0, kernel.n0, &transposeKernel)) { return failure(); } @@ -928,9 +1167,11 @@ class MMTCustomKernelPattern : public OpRewritePattern { return failure(); } Value unpromotedLhs = - getUnpromotedInput(lhsElemType, accElemType, contractionOp.getLhs()); + getUnpromotedInput(rewriter, lhsElemType, accElemType, + contractionOp.getLhs(), kernel.promoteSmallTypes); Value unpromotedRhs = - getUnpromotedInput(rhsElemType, accElemType, contractionOp.getRhs()); + getUnpromotedInput(rewriter, rhsElemType, accElemType, + contractionOp.getRhs(), kernel.promoteSmallTypes); if (!unpromotedLhs || !unpromotedRhs) { return failure(); } @@ -952,9 +1193,23 @@ class MMTCustomKernelPattern : public OpRewritePattern { // `contractionOp` matches, start rewriting it. Location loc = contractionOp.getLoc(); // Flatten the inputs to 1D vectors. 
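To make the operand plumbing above concrete, a minimal sketch (assumed example, not part of the patch) of the $(name:i) -> $idx rewrite performed by processOperands for the VNNI kernel (4 accumulator registers, 1 LHS register, 4 RHS registers); accumulators are numbered first, then LHS, then RHS:

// Illustrative only; replaceAllSubstrsInPlace is the helper already used above.
std::string code = "vpdpwssd $(acc:2), $(rhs:2), $(lhs:0)";
replaceAllSubstrsInPlace(code, "$(acc:2)", "$2");  // acc regs 0..3 -> $0..$3
replaceAllSubstrsInPlace(code, "$(rhs:2)", "$7");  // rhs regs 0..3 -> $5..$8
replaceAllSubstrsInPlace(code, "$(lhs:0)", "$4");  // the single lhs reg -> $4
// code == "vpdpwssd $2, $7, $4"; on x86 each operand gets constraint code "v"
// unless the kernel supplies per-register constraint codes.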
- Value flatLhs = flatten(rewriter, loc, unpromotedLhs); - Value flatRhs = flatten(rewriter, loc, unpromotedRhs); - Value flatAcc = flatten(rewriter, loc, contractionOp.getAcc()); + VectorType lhsRegVectorType = generator.getLhsRegVectorType(); + VectorType rhsRegVectorType = generator.getRhsRegVectorType(); + VectorType accRegVectorType = generator.getAccRegVectorType(); + Value lhs, rhs; + if (transposeKernel) { + lhs = + flattenImperfectSize(rewriter, loc, unpromotedLhs, rhsRegVectorType); + rhs = + flattenImperfectSize(rewriter, loc, unpromotedRhs, lhsRegVectorType); + } else { + lhs = + flattenImperfectSize(rewriter, loc, unpromotedLhs, lhsRegVectorType); + rhs = + flattenImperfectSize(rewriter, loc, unpromotedRhs, rhsRegVectorType); + } + Value acc = flattenImperfectSize(rewriter, loc, contractionOp.getAcc(), + accRegVectorType); // Slice into SIMD-register-sized 1D input vectors ready to feed to the // target SIMD instructions. auto sliceIntoRegVectors = [&](int regsCount, VectorType regVectorType, @@ -967,17 +1222,14 @@ class MMTCustomKernelPattern : public OpRewritePattern { } return regVectors; }; - VectorType lhsRegVectorType = generator.getLhsRegVectorType(); - VectorType rhsRegVectorType = generator.getRhsRegVectorType(); - VectorType accRegVectorType = generator.getAccRegVectorType(); - Value flatLhsForKernel = transposeKernel ? flatRhs : flatLhs; - Value flatRhsForKernel = transposeKernel ? flatLhs : flatRhs; + Value lhsForKernel = transposeKernel ? rhs : lhs; + Value rhsForKernel = transposeKernel ? lhs : rhs; SmallVector lhsRegVectors = - sliceIntoRegVectors(kernel.lhsRegs, lhsRegVectorType, flatLhsForKernel); + sliceIntoRegVectors(kernel.lhsRegs, lhsRegVectorType, lhsForKernel); SmallVector rhsRegVectors = - sliceIntoRegVectors(kernel.rhsRegs, rhsRegVectorType, flatRhsForKernel); + sliceIntoRegVectors(kernel.rhsRegs, rhsRegVectorType, rhsForKernel); SmallVector accRegVectors = - sliceIntoRegVectors(kernel.accRegs, accRegVectorType, flatAcc); + sliceIntoRegVectors(kernel.accRegs, accRegVectorType, acc); // Generate the kernel! 
SmallVector resRegVectors = generator.generate( rewriter, loc, lhsRegVectors, rhsRegVectors, accRegVectors); @@ -1036,8 +1288,8 @@ struct MMT_8x4x8_i8i8i32_Aarch64Dotprod_Intrinsics return failure(); } - Value inLhs = getUnpromotedInput(I8Type, I32Type, lhs); - Value inRhs = getUnpromotedInput(I8Type, I32Type, rhs); + Value inLhs = getUnpromotedInput(rewriter, I8Type, I32Type, lhs); + Value inRhs = getUnpromotedInput(rewriter, I8Type, I32Type, rhs); if (!inLhs || !inRhs) return failure(); @@ -1170,6 +1422,15 @@ void populateVectorContractCustomKernelsPatterns( patterns.add( context, MMTKernel_8x8x8_i8i8i32_Aarch64I8mm_InlineAsm()); } + } else if (isX86(target)) { + if (hasFeature(target, "+avx512vnni")) { + patterns.add( + context, + MMTKernel_1x2x4_split16_i16i16i32_x86_AVX512VNNI_InlineAsm()); + } else if (hasFeature(target, "+avx512bw")) { + patterns.add( + context, MMTKernel_1x2x4_split16_i16i16i32_x86_AVX512_InlineAsm()); + } } } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index ef26385f57ea..577c248cc09e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -22,7 +22,9 @@ #include "llvm/Support/Debug.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" @@ -924,6 +926,25 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint, if ((groupSize / subgroupSize) > subgroupSize) return failure(); + // With just one subgroup per workgroup, make each subgroup do more work and + // process a few reductions along the last parallel dimension. + // TODO: We should also check that this will result in data reuse for at least + // one argument. + // TODO: This is experimental for matvec (matmul_transpose_b) on rocm-only for + // now. + if (numDynamicReductionDims == 0 && numParallelDims == 2 && + isRocmTarget(entryPoint)) { + if (*parallelSize && !parallelDims.empty() && groupSize == subgroupSize) { + int maxParallelFactor = 4; // Keeping this conservative for now. + int64_t lastParallelBound = bounds[parallelDims.back()]; + if (!ShapedType::isDynamic(lastParallelBound) && + (lastParallelBound % maxParallelFactor == 0) && + lastParallelBound > maxParallelFactor) { + workgroupTileSizes.back() = maxParallelFactor; + } + } + } + std::array workgroupSize = {groupSize, 1, 1}; SmallVector reductionTileSizes(op.getNumLoops(), 0); int64_t remainingGroupSize = groupSize; diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index e529e8a1017a..c0bbef5614fd 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -501,7 +501,7 @@ static void addLowerAndOptimzeAddressComputation(OpPassManager &pm) { pm.addPass(createExtractAddressComputationGPUPass()); pm.addNestedPass(memref::createExpandOpsPass()); pm.addPass(memref::createFoldMemRefAliasOpsPass()); - pm.addPass(memref::createExpandStridedMetadataPass()); + pm.addPass(createIREEExpandStridedMetadataPass()); // Hoist loop invariant variables to give decompose affine pass the right loop // dependencies. 
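As a rough walk-through of the new matvec heuristic added to setWarpReductionConfig above (illustrative numbers taken from the gfx940 test added below; not part of the patch):

// 1x32000xf16 = 1x4096xf16 * 32000x4096xf16, groupSize == subgroupSize == 64,
// parallel dims bounded by (1, 32000), reduction dim bounded by 4096.
int maxParallelFactor = 4;              // kept conservative in the patch
int64_t lastParallelBound = 32000;      // bounds[parallelDims.back()]
bool tileLastParallel = (lastParallelBound % maxParallelFactor == 0) &&
                        (lastParallelBound > maxParallelFactor);
// tileLastParallel == true, so workgroupTileSizes.back() becomes 4: each
// single-subgroup workgroup now reduces four output elements instead of one.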
pm.addPass(createLoopInvariantCodeMotionPass()); @@ -575,7 +575,7 @@ static void addLowerToLLVMGPUPasses(OpPassManager &pm, bool useROCM) { pm.addNestedPass(memref::createExpandOpsPass()); pm.addPass(memref::createFoldMemRefAliasOpsPass()); - pm.addPass(memref::createExpandStridedMetadataPass()); + pm.addPass(createIREEExpandStridedMetadataPass()); pm.addPass(createEmulateNarrowTypePass()); pm.addPass(createLowerAffinePass()); pm.addPass(createCanonicalizerPass()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir index 47b34315160d..2cfa7a8b3aeb 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir @@ -50,3 +50,50 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {target_arch = "gf // CHECK: func.func @dynamic_batch_matvec() // CHECK: linalg.batch_matmul // CHECK-SAME: lowering_config = #[[$CONFIG]] + +// ----- + +#pipeline_layout = #hal.pipeline.layout, + #hal.descriptor_set.binding<1, storage_buffer>, + #hal.descriptor_set.binding<2, storage_buffer> + ]> +]> + +hal.executable @vmt { +hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {target_arch = "gfx940"}>) { + hal.executable.export @vmt layout(#pipeline_layout) + builtin.module { + func.func @vmt() { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 4096], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<1x4096xf16> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [32000, 4096], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<32000x4096xf16> + %5 = tensor.empty() : tensor<1x32000xf16> + %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<1x32000xf16>) -> tensor<1x32000xf16> + %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%3, %4 : tensor<1x4096xf16>, tensor<32000x4096xf16>) outs(%6 : tensor<1x32000xf16>) { + ^bb0(%in: f16, %in_0: f16, %out: f16): + %8 = arith.mulf %in, %in_0 : f16 + %9 = arith.addf %out, %8 : f16 + linalg.yield %9 : f16 + } -> tensor<1x32000xf16> + flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [1, 32000], strides = [1, 1] : tensor<1x32000xf16> -> !flow.dispatch.tensor> + return + } + } + } +} + +// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-LABEL: hal.executable.export public @vmt +// CHECK-SAME: subgroup_size = 64 : index +// CHECK-SAME: translation_info = #[[$TRANSLATION]] +// CHECK-SAME: workgroup_size = [64 : index, 1 : index, 1 : index] +// CHECK: func.func @vmt() +// CHECK: linalg.generic +// CHECK-SAME: lowering_config = #[[$CONFIG]] diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp index 5ca39ceb9a18..d094e0864834 
100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp @@ -84,6 +84,7 @@ static llvm::cl::opt clDispatchGenerateWorkloadRegion( "iree-flow-dispatch-generate-workload-region", llvm::cl::desc("Generate the workload region."), llvm::cl::init(true)); + static llvm::cl::opt clNormalizeInputIndexingMap( "iree-flow-normalize-input-indexing-map", llvm::cl::desc("Enable normalizing input indexing map to identity."), diff --git a/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp b/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp index 7ec2388020b4..49d895540b29 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp @@ -8,6 +8,7 @@ #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" +#include "iree/compiler/GlobalOptimization/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" @@ -776,6 +777,14 @@ static LogicalResult reassociateDequantMatmul(RewriterBase &rewriter, rewriter.replaceOp(matmul, reassociatedDequantization.getResult(0)); + // Fuse dequantization + matmul ops into a single dispatch region + SmallVector dequantMatmulOps{quantizedIntegerMatmul, + reassociatedDequantization}; + FailureOr maybeDequantMatmulDispatch = + wrapConsecutiveOpsInDispatchRegion(rewriter, dequantMatmulOps); + if (failed(maybeDequantMatmulDispatch)) { + return failure(); + } return success(); } diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp index 9c1ddf178628..88ce7a1672cf 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp @@ -68,7 +68,9 @@ void buildGlobalOptimizationPassPipeline( .addPass(createRemoveZeroExtentTensorsPass) .addPass(createDetachElementwiseFromNamedOpsPass) .addPass(mlir::createLinalgNamedOpConversionPass) - .addPass(createConvert1X1FilterConv2DToMatmulPass); + .addPass(createConvert1X1FilterConv2DToMatmulPass) + .addPredicatedPass(!clEnableQuantizedMatmulReassociation, + createLiftGenericToTransposeBatchMatmulPass); mainPassManager.addPass(createEraseUnusedLinalgOperands()); // Expand tensor shapes into SSA values and optimize the whole program. 
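The addPredicatedPass call above amounts to a conditional addPass on the same function-level pass nest; a rough sketch of the intended behavior (assuming addPredicatedPass simply gates the pass on its boolean argument; funcPassNest is a hypothetical handle to the nest the surrounding .addPass calls are chained on):

// Sketch only: skip the lifting pass whenever quantized matmul reassociation
// is enabled, add it otherwise.
if (!clEnableQuantizedMatmulReassociation) {
  funcPassNest.addPass(createLiftGenericToTransposeBatchMatmulPass);
}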
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir index fd8d196bcebe..e535c52eafcb 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir @@ -97,6 +97,7 @@ module { // REASSOCIATE-CHECK: %[[INITMATMUL:.+]] = tensor.empty() : tensor<11008x32xi32> // REASSOCIATE-CHECK: %[[FILLMATMUL:.+]] = linalg.fill ins(%[[C0I32]] // REASSOCIATE-CHECK-SAME: outs(%[[INITMATMUL]] : +// REASSOCIATE-CHECK: %[[DISP:.+]] = flow.dispatch.region // REASSOCIATE-CHECK: %[[GENMATMUL:.+]] = linalg.generic // REASSOCIATE-CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP5]]] // REASSOCIATE-CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] @@ -122,4 +123,5 @@ module { // REASSOCIATE-CHECK: %[[RESUBF:.+]] = arith.subf %[[REMULF1]], %[[REMULF3]] : f32 // REASSOCIATE-CHECK: %[[READDF:.+]] = arith.addf %[[RESUBF]], %[[REOUT0]] : f32 // REASSOCIATE-CHECK: linalg.yield %[[READDF]] : f32 -// REASSOCIATE-CHECK: return %[[GENREASSOCIATE]] +// REASSOCIATE-CHECK: flow.return %[[GENREASSOCIATE]] +// REASSOCIATE-CHECK: return %[[DISP]] diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel index c9acfa23b3e5..f33d8712093b 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel @@ -31,6 +31,9 @@ iree_compiler_cc_library( name = "Transforms", srcs = [ "ConvertConv2DToImg2Col.cpp", + "ConvertConvNchwToNhwc.cpp", + "ConvertConvToChannelsLast.cpp", + "GeneralizeConvolutions.cpp", "MakeSingleDispatchForFunction.cpp", "PadLinalgOps.cpp", "PassDetail.h", diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt index fb7b26ff5dcf..09e61373aed9 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt @@ -27,6 +27,9 @@ iree_cc_library( "Passes.h.inc" SRCS "ConvertConv2DToImg2Col.cpp" + "ConvertConvNchwToNhwc.cpp" + "ConvertConvToChannelsLast.cpp" + "GeneralizeConvolutions.cpp" "MakeSingleDispatchForFunction.cpp" "PadLinalgOps.cpp" "PassDetail.h" diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp index 8bdb0120abdf..c06717a68959 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp @@ -21,6 +21,8 @@ namespace mlir::iree_compiler::Preprocessing { +static const char winogradAttr[] = "iree_winograd_conv"; + static bool hasAllOneValues(DenseIntElementsAttr attr) { return llvm::all_of( attr, [](APInt element) { return element.getSExtValue() == 1; }); @@ -94,6 +96,9 @@ class ConvertConv2DNhwcHwcf final if (!hasAllOneValues(convOp.getDilations())) return failure(); + // Ignore if marked as Winograd convolution + if (convOp->hasAttr(winogradAttr)) return failure(); + Value input = convOp.getInputs()[0]; Value filter = convOp.getInputs()[1]; Value output = convOp.getOutputs()[0]; @@ -403,6 +408,9 @@ class ConvertConv2DNchwFchw final if (!hasAllOneValues(convOp.getDilations())) return failure(); + // Ignore if marked as 
Winograd convolution + if (convOp->hasAttr(winogradAttr)) return failure(); + Value input = convOp.getInputs()[0]; Value filter = convOp.getInputs()[1]; Value output = convOp.getOutputs()[0]; diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvNchwToNhwc.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvNchwToNhwc.cpp new file mode 100644 index 000000000000..fdf60f9e4493 --- /dev/null +++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvNchwToNhwc.cpp @@ -0,0 +1,560 @@ +// Copyright 2020 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Preprocessing/Common/PassDetail.h" +#include "iree/compiler/Preprocessing/Common/Passes.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-flow-convert-conv-nchw-to-nhwc" + +namespace mlir::iree_compiler::Preprocessing { + +using TransposeIndices = SmallVector; + +static const StringLiteral transposeEmptyMarker = "__nchw_to_nhwc_init__"; +static const StringLiteral transposePropagateUpMarker = "__nchw_to_nhwc_up__"; +static const StringLiteral transposePropagateDownMarker = + "__nchw_to_nhwc_down__"; + +static TransposeIndices invertIndices(TransposeIndices targetIndices) { + auto rank = targetIndices.size(); + TransposeIndices inverted(rank); + for (auto i : llvm::enumerate(targetIndices)) { + inverted[i.value()] = i.index(); + } + return inverted; +} + +static TransposeIndices getTransposeIndices(linalg::TransposeOp op) { + return llvm::to_vector(op.getPermutation()); +} + +static bool isStaticallyShaped(Value input) { + if (auto inputType = input.getType().dyn_cast()) + return inputType.hasStaticShape(); + return false; +} + +// Get the transpose indices if the given input comes from a transpose and is +// marked to propagate down. +static std::optional getIndicesFromInput(Value input) { + if (!isStaticallyShaped(input)) return std::nullopt; + auto parent = input.getDefiningOp(); + if (parent && parent->hasAttr(transposePropagateDownMarker)) + return getTransposeIndices(parent); + return std::nullopt; +} + +// Get the transpose indices if the given output is used by at least one +// transpose and that transpose is marked to propagate up. Additionally don't +// propagate if there are conflicting transposes. +static std::optional getIndicesFromOutput(Value output) { + if (!isStaticallyShaped(output)) return std::nullopt; + std::optional transposedOut; + if (llvm::all_of(output.getUses(), [&transposedOut](const OpOperand &use) { + auto owner = dyn_cast(use.getOwner()); + if (owner && owner->hasAttr(transposePropagateUpMarker)) { + if (transposedOut.has_value()) { + if (getTransposeIndices(transposedOut.value()) == + getTransposeIndices(owner)) + return true; + return false; + } + transposedOut = owner; + return true; + } + return false; + })) { + if (transposedOut.has_value()) + return getTransposeIndices(transposedOut.value()); + } + return std::nullopt; +} + +// Helper to shuffle vectors according to the transpose indices. 
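As a concrete illustration of invertIndices above and the shuffleFromIndices helper that follows (the shape values are an assumed example, not from the patch):

// shuffleFromIndices picks shuffled[i] = unshuffled[targetIndices[i]];
// invertIndices builds the permutation that undoes it.
SmallVector<int64_t> nchwShape = {1, 64, 56, 56};  // hypothetical N, C, H, W
TransposeIndices toNhwc = {0, 2, 3, 1};            // same as NCHWIndices below
// shuffleFromIndices(nchwShape, toNhwc) == {1, 56, 56, 64}   (N, H, W, C)
// invertIndices(toNhwc)                 == {0, 3, 1, 2}      (NHWC -> NCHW)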
+template +static SmallVector shuffleFromIndices(SmallVector unshuffled, + TransposeIndices targetIndices) { + auto rank = unshuffled.size(); + assert(targetIndices.size() == rank && + "Mismatch between number of elements in input and number of indices"); + SmallVector shuffled(rank); + + for (auto i : llvm::enumerate(targetIndices)) { + shuffled[i.index()] = unshuffled[i.value()]; + } + return shuffled; +} + +// Transpose the given tensor based on the given transpose indices. Marks the +// created transpose based on the propagation direction. +static Value createTranspose(PatternRewriter &rewriter, Location loc, + Value input, TransposeIndices targetIndices, + bool propagateUp) { + RankedTensorType inType = input.getType().cast(); + auto elementType = inType.getElementType(); + auto inputShape(inType.getShape()); + + auto outputShape = + shuffleFromIndices(llvm::to_vector(inputShape), targetIndices); + + Value output = + rewriter.create(loc, outputShape, elementType); + output.getDefiningOp()->setAttr(transposeEmptyMarker, rewriter.getUnitAttr()); + + auto transpose = + rewriter.create(loc, input, output, targetIndices); + transpose->setAttr( + propagateUp ? transposePropagateUpMarker : transposePropagateDownMarker, + rewriter.getUnitAttr()); + return transpose.getResults()[0]; +} + +// Supports conv and pooling ops, where pooling ops don't transpose the filter. +template +static LogicalResult convertConvLikeNchwToNhwc(PatternRewriter &rewriter, + ConvOpTy convOp, + bool transposeFilter) { + LLVM_DEBUG(llvm::dbgs() << "inspecting " << convOp << "\n"); + + Location loc = convOp.getLoc(); + + Value input = convOp.image(); + Value filter = convOp.filter(); + Value output = convOp.getOutputs()[0]; + + if (!isStaticallyShaped(input) || !isStaticallyShaped(output) || + (transposeFilter && !isStaticallyShaped(filter))) { + return failure(); + } + + TransposeIndices NCHWIndices = {0, 2, 3, 1}; + + auto transposedInput = + createTranspose(rewriter, loc, input, NCHWIndices, true); + auto transposedFilter = filter; + if (transposeFilter) { + TransposeIndices FCHWIndices = {2, 3, 1, 0}; + transposedFilter = + createTranspose(rewriter, loc, filter, FCHWIndices, true); + } + auto transposedOutput = + createTranspose(rewriter, loc, output, NCHWIndices, true); + + auto conv = + rewriter + .create(loc, transposedOutput.getType(), + ValueRange{transposedInput, transposedFilter}, + transposedOutput, convOp.getStrides(), + convOp.getDilations()) + .getResult(0); + + auto returnToNCHW = + createTranspose(rewriter, loc, conv, invertIndices(NCHWIndices), false); + + rewriter.replaceOp(convOp, returnToNCHW); + return success(); +} + +namespace { + +/* + * Convolution conversion patterns + */ + +struct ConvertLinalgConvNchwFchw : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp, + PatternRewriter &rewriter) const override { + return convertConvLikeNchwToNhwc(rewriter, convOp, + true); + } +}; + +struct ConvertLinalgPoolingNchwMax + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::PoolingNchwMaxOp poolOp, + PatternRewriter &rewriter) const override { + return convertConvLikeNchwToNhwc(rewriter, poolOp, + false); + } +}; + +struct ConvertLinalgPoolingNchwSum + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::PoolingNchwSumOp poolOp, + PatternRewriter &rewriter) const override { + return convertConvLikeNchwToNhwc(rewriter, 
poolOp, + false); + } +}; + +/* + * Transpose propagation patterns + */ + +struct PropagateThroughTensorPadPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PropagateThroughTensorPadPattern(MLIRContext *context, bool propagateUp) + : OpRewritePattern(context), propagateUp(propagateUp) {} + + LogicalResult matchAndRewrite(tensor::PadOp padOp, + PatternRewriter &rewriter) const override { + TransposeIndices transposeIndices; + + if (propagateUp) { + auto indices = getIndicesFromOutput(padOp.getResult()); + if (!indices.has_value()) return failure(); + transposeIndices = indices.value(); + } else { + auto indices = getIndicesFromInput(padOp.getSource()); + if (!indices.has_value()) return failure(); + transposeIndices = invertIndices(indices.value()); + } + + LLVM_DEBUG(llvm::dbgs() << "propagating " << padOp << "\n"); + + Location loc = padOp.getLoc(); + + auto input = padOp.getSource(); + SmallVector mixedLow = shuffleFromIndices( + padOp.getMixedLowPad(), transposeIndices); + SmallVector mixedHigh = shuffleFromIndices( + padOp.getMixedHighPad(), transposeIndices); + + auto transposedInput = + createTranspose(rewriter, loc, input, transposeIndices, true); + + SmallVector outputShape(padOp.getResultType().getShape()); + SmallVector transposedOutputShape = + shuffleFromIndices(outputShape, transposeIndices); + RankedTensorType transposedOutputType = RankedTensorType::get( + transposedOutputShape, padOp.getResultType().getElementType()); + + auto newPad = rewriter.create(loc, transposedOutputType, + transposedInput, mixedLow, + mixedHigh, padOp.getNofold()); + IRMapping mapper; + padOp.getRegion().cloneInto(&newPad.getRegion(), mapper); + + auto returnToNCHW = createTranspose(rewriter, loc, newPad.getResult(), + invertIndices(transposeIndices), false); + + rewriter.replaceOp(padOp, returnToNCHW); + return success(); + } + + private: + bool propagateUp; +}; + +struct PropagateThroughLinalgFillPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PropagateThroughLinalgFillPattern(MLIRContext *context, bool propagateUp) + : OpRewritePattern(context), propagateUp(propagateUp) {} + + LogicalResult matchAndRewrite(linalg::FillOp fillOp, + PatternRewriter &rewriter) const override { + TransposeIndices transposeIndices; + + if (propagateUp) { + auto indices = getIndicesFromOutput(fillOp.getResult(0)); + if (!indices.has_value()) return failure(); + transposeIndices = indices.value(); + } else { + auto indices = getIndicesFromInput(fillOp.value()); + if (!indices.has_value()) return failure(); + transposeIndices = invertIndices(indices.value()); + } + + LLVM_DEBUG(llvm::dbgs() << "propagating " << fillOp << "\n"); + Location loc = fillOp.getLoc(); + + auto transposedOutput = + createTranspose(rewriter, loc, fillOp.output(), transposeIndices, true); + + auto newTensor = + rewriter.create(loc, fillOp.value(), transposedOutput) + .getResult(0); + + auto returnToNCHW = createTranspose(rewriter, loc, newTensor, + invertIndices(transposeIndices), false); + + rewriter.replaceOp(fillOp, returnToNCHW); + return success(); + } + + private: + bool propagateUp; +}; + +struct PropagateThroughLinalgGenericPattern + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PropagateThroughLinalgGenericPattern(MLIRContext *context, bool propagateUp) + : OpRewritePattern(context), + propagateUp(propagateUp) {} + + LogicalResult matchAndRewrite(linalg::GenericOp genericOp, + PatternRewriter &rewriter) const override { + TransposeIndices transposeIndices; + + 
// For now restrict to single results. + if (genericOp.getNumResults() != 1) return failure(); + + if (propagateUp) { + auto indices = getIndicesFromOutput(genericOp.getOutputs()[0]); + if (!indices.has_value()) return failure(); + transposeIndices = indices.value(); + } else { + // TODO: Enable directly fusing the transpose with the inputs. + return failure(); + } + + LLVM_DEBUG(llvm::dbgs() << "propagating " << genericOp << "\n"); + + Location loc = genericOp.getLoc(); + + auto transposedOutput = genericOp.getOutputs()[0]; + auto indexingMaps = genericOp.getIndexingMapsArray(); + + if (propagateUp) { + transposedOutput = createTranspose(rewriter, loc, transposedOutput, + transposeIndices, true); + + AffineMap outMap = indexingMaps.back(); + SmallVector outExprs(outMap.getResults()); + SmallVector exprs = + shuffleFromIndices(outExprs, transposeIndices); + indexingMaps[indexingMaps.size() - 1] = + AffineMap::get(outMap.getNumDims(), outMap.getNumSymbols(), exprs, + genericOp->getContext()); + } + + SmallVector newInputs; + for (auto input : llvm::enumerate(genericOp.getInputs())) { + newInputs.push_back(input.value()); + } + + SmallVector iteratorTypes = + genericOp.getIteratorTypesArray(); + + auto newGeneric = rewriter.create( + loc, transposedOutput.getType().cast(), newInputs, + transposedOutput, indexingMaps, iteratorTypes); + IRMapping mapper; + genericOp.getRegion().cloneInto(&newGeneric.getRegion(), mapper); + + Value returnToNCHW = newGeneric.getResult(0); + if (propagateUp) { + returnToNCHW = createTranspose(rewriter, loc, returnToNCHW, + invertIndices(transposeIndices), false); + } + + rewriter.replaceOp(genericOp, returnToNCHW); + return success(); + } + + private: + bool propagateUp; +}; + +struct PropagateThroughTensorEmptyPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::EmptyOp emptyOp, + PatternRewriter &rewriter) const override { + if (emptyOp->hasAttr(transposeEmptyMarker)) return failure(); + TransposeIndices transposeIndices; + + auto indices = getIndicesFromOutput(emptyOp.getResult()); + if (!indices.has_value()) return failure(); + transposeIndices = indices.value(); + + LLVM_DEBUG(llvm::dbgs() << "propagating " << emptyOp << "\n"); + + Location loc = emptyOp.getLoc(); + + SmallVector mixedSizes = shuffleFromIndices( + emptyOp.getMixedSizes(), transposeIndices); + + auto newTensor = rewriter.create( + loc, mixedSizes, emptyOp.getType().getElementType()); + auto returnToNCHW = createTranspose(rewriter, loc, newTensor.getResult(), + invertIndices(transposeIndices), false); + + rewriter.replaceOp(emptyOp, returnToNCHW); + return success(); + } +}; + +/* + * Folding away cancelling transposes and generalizing + */ + +// Cancel if this transpose is tagged with a propagating tag and the defining op +// for the input is the inverse of this transpose +struct CancelNCHWToNHWCTransposePattern + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + auto transposeIndices = invertIndices(getTransposeIndices(transposeOp)); + + auto parentOp = + transposeOp->getOperand(0).getDefiningOp(); + if (parentOp) { + if (getTransposeIndices(parentOp) == transposeIndices) { + rewriter.replaceOp(transposeOp, parentOp->getOperand(0)); + return success(); + } + } + + return failure(); + } +}; + +struct GeneralizeTransposeOpPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + 
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + if (transposeOp->hasAttr(transposePropagateUpMarker) || + transposeOp->hasAttr(transposePropagateDownMarker)) { + auto context = rewriter.getContext(); + auto rank = + transposeOp.getResultTypes()[0].cast().getRank(); + + auto transposeIndices = getTransposeIndices(transposeOp); + + SmallVector idExprs; + for (auto i = 0; i < rank; i++) + idExprs.push_back(getAffineDimExpr(i, context)); + + SmallVector swapExprs = + shuffleFromIndices(idExprs, transposeIndices); + + SmallVector indexingMaps = { + AffineMap::get(rank, 0, idExprs, context), + AffineMap::get(rank, 0, swapExprs, context)}; + SmallVector iteratorTypes( + rank, utils::IteratorType::parallel); + + rewriter.replaceOpWithNewOp( + transposeOp, transposeOp.getResultTypes()[0], + transposeOp.getOperand(0), transposeOp.getOperand(1), indexingMaps, + iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args[0]); + }); + return success(); + } + return failure(); + } +}; + +// The high level strategy for this pass is as follows: +// 1. Do the conversions for all conv_nchw_fchw ops (and pooling ops) and +// wrap the converted convolutions in transposes. Each transpose is tagged +// to indicate which direction the transpose should propagate through the +// graph. +// 2. Traverse the ops in the function in reverse to propagate transposes +// marked for upwards propagation to their parents. Ideally just before ops +// such as arith.constant or function arguments. +// 3. Propagate the transposes marked for downward propagation to its users, +// ideally to just before return. +// 4. Canonicalize out all adjacent cancelling transposes and generalize the +// remaining transposes to allow for fusing them with nearby ops. +struct ConvertConvNchwToNhwcPass + : public ConvertConvNchwToNhwcBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + Operation *funcOp = getOperation(); + MLIRContext *context = &getContext(); + + { + RewritePatternSet patterns(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); + } + } + + // Propagate transposes up the graph. + { + SmallVector ops; + funcOp->walk([&](Operation *op) { ops.push_back(op); }); + + RewritePatternSet patterns(context); + patterns.insert(context, true); + patterns.insert(context); + patterns.insert(context, true); + patterns.insert(context, true); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + + SmallVector reverseOps(llvm::reverse(ops)); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::AnyOp; + (void)applyOpPatternsAndFold(reverseOps, frozenPatterns, config); + } + + // Propagate transposes down the graph. + { + RewritePatternSet patterns(context); + patterns.insert(context, false); + patterns.insert(context, false); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } + + // Cancel out transposes. + { + RewritePatternSet patterns(context); + patterns.insert(context); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } + + // Generalize remaining transposes to allow fusion with other ops. 
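For reference, the maps GeneralizeTransposeOpPattern builds for a rank-4 transpose with permutation {0, 2, 3, 1}; a minimal sketch assuming an MLIRContext named ctx (illustrative, not part of the patch):

SmallVector<AffineExpr> idExprs = {getAffineDimExpr(0, ctx), getAffineDimExpr(1, ctx),
                                   getAffineDimExpr(2, ctx), getAffineDimExpr(3, ctx)};
SmallVector<AffineExpr> swapExprs = shuffleFromIndices(idExprs, {0, 2, 3, 1});
// swapExprs == (d0, d2, d3, d1), so the replacement linalg.generic carries
//   input map : (d0, d1, d2, d3) -> (d0, d1, d2, d3)
//   output map: (d0, d1, d2, d3) -> (d0, d2, d3, d1)
// with all-parallel iterators and a body that just yields the input element.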
+ { + RewritePatternSet patterns(context); + patterns.insert(context); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } + } +}; + +} // namespace + +std::unique_ptr> +createConvertConvNchwToNhwcPass() { + return std::make_unique(); +} + +} // namespace mlir::iree_compiler::Preprocessing diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp new file mode 100644 index 000000000000..6ee82e170b3f --- /dev/null +++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp @@ -0,0 +1,886 @@ +// Copyright 2020 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Preprocessing/Common/PassDetail.h" +#include "iree/compiler/Preprocessing/Common/Passes.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-preprocessing-convert-conv-to-channels-last" + +namespace mlir::iree_compiler::Preprocessing { + +static const StringLiteral fullTileTransposeMarker = "__fully_transpose_tile__"; + +using TransposeIndices = SmallVector; +using ConvBuilderFn = std::function newDimOrder, + SmallVector newIteratorTypes)>; +using linalg::detail::MatchConvolutionResult; + +static Value defaultConvBuilderFn( + OpBuilder &b, Location loc, linalg::LinalgOp srcConv, Value input, + Value filter, Value output, AffineMap inputMap, AffineMap filterMap, + AffineMap outputMap, SmallVector newDimOrder, + SmallVector newIteratorTypes) { + AffineMap newInputMap = inputMap; + AffineMap newFilterMap = filterMap; + AffineMap newOutputMap = outputMap; + if (!newDimOrder.empty()) { + DenseMap dimMap; + for (auto [newDim, oldDim] : llvm::enumerate(newDimOrder)) + dimMap[b.getAffineDimExpr(oldDim)] = b.getAffineDimExpr(newDim); + newInputMap = inputMap.replace(dimMap, + /*numResultDims=*/newDimOrder.size(), + /*numResultSymbols=*/0); + newFilterMap = filterMap.replace(dimMap, + /*numResultDims=*/newDimOrder.size(), + /*numResultSymbols=*/0); + newOutputMap = outputMap.replace(dimMap, + /*numResultDims=*/newDimOrder.size(), + /*numResultSymbols=*/0); + } + SmallVector iterators = srcConv.getIteratorTypesArray(); + iterators.append(newIteratorTypes); + auto genericConv = b.create( + loc, output.getType(), ValueRange{input, filter}, output, + ArrayRef{newInputMap, newFilterMap, newOutputMap}, iterators); + IRMapping mapper; + srcConv->getRegion(0).cloneInto(&genericConv.getRegion(), mapper); + return genericConv.getResult(0); +} + +template +static Value namedConvBuilderFn( + OpBuilder &b, Location loc, linalg::LinalgOp srcConv, Value input, + Value filter, Value output, AffineMap inputMap, AffineMap filterMap, + AffineMap outputMap, SmallVector newDimOrder, + SmallVector newIteratorTypes) { + sourceNamedConvTy namedConv = cast(srcConv); + return b + 
.create( + loc, output.getType(), ValueRange{input, filter}, output, + namedConv.getStrides(), namedConv.getDilations()) + .getResult(0); +} + +static TransposeIndices getNormalizedIndices(TransposeIndices targetIndices) { + int startDim = *std::min_element(targetIndices.begin(), targetIndices.end()); + TransposeIndices normalized(targetIndices.size()); + for (auto i : llvm::enumerate(targetIndices)) + normalized[i.index()] = i.value() - startDim; + return normalized; +} + +static TransposeIndices invertIndices(TransposeIndices targetIndices) { + int startDim = *std::min_element(targetIndices.begin(), targetIndices.end()); + TransposeIndices inverted(targetIndices.size()); + for (auto i : llvm::enumerate(targetIndices)) { + inverted[i.value() - startDim] = i.index() + startDim; + } + return inverted; +} + +static bool isInnerIdentityIndices(TransposeIndices indices, int64_t rank) { + return indices.empty() || + (llvm::all_of(llvm::enumerate(indices), + [indices](auto e) { + if (e.index() == 0) return true; + return indices[e.index() - 1] < e.value(); + }) && + indices.back() == rank - 1); +} + +// Helper to shuffle vectors according to the transpose indices. +template +static SmallVector shuffleFromIndices(SmallVector unshuffled, + TransposeIndices targetIndices) { + int startDim = *std::min_element(targetIndices.begin(), targetIndices.end()); + SmallVector shuffled(unshuffled); + for (auto i : llvm::enumerate(targetIndices)) { + shuffled[i.index() + startDim] = unshuffled[i.value()]; + } + return shuffled; +} + +template +static SmallVector getPackedVector(SmallVector vec, + TransposeIndices targetIndices) { + SmallVector packedShape; + for (auto [i, val] : llvm::enumerate(vec)) + if (!llvm::is_contained(targetIndices, i)) packedShape.push_back(val); + for (auto i : targetIndices) packedShape.push_back(vec[i]); + return packedShape; +} + +static SmallVector getUntiledPackReassociationMap( + TransposeIndices targetIndices, int64_t rank) { + int startDim = *std::min_element(targetIndices.begin(), targetIndices.end()); + int dimCount = targetIndices.size(); + SmallVector reassociationMap; + for (int i = 0; i <= startDim; i++) reassociationMap.push_back({i}); + for (int i = startDim + 1; i < dimCount + startDim + 1; i++) + reassociationMap[startDim].push_back(i); + for (int i = dimCount + startDim + 1; i < dimCount + rank; i++) + reassociationMap.push_back({i}); + return reassociationMap; +} + +// Transpose the given tensor based on the given transpose indices. Marks the +// created transpose based on the propagation direction. +static std::tuple, AffineMap> +createTransposeAsTensorPack( + PatternRewriter &rewriter, Location loc, Value input, AffineMap inputMap, + TransposeIndices targetIndices, int tilingFactor, + llvm::DenseMap innerDimToDomainDim) { + if (isInnerIdentityIndices(targetIndices, inputMap.getNumResults())) + return std::make_tuple(input, std::nullopt, inputMap); + + RankedTensorType inType = input.getType().cast(); + auto elementType = inType.getElementType(); + auto inputShape(inType.getShape()); + + SmallVector transposedTileSizes( + targetIndices.size(), rewriter.getIndexAttr(tilingFactor)); + if (tilingFactor <= 0) { + for (auto [index, i] : llvm::enumerate(targetIndices)) { + if (ShapedType::isDynamic(inputShape[i])) + transposedTileSizes[index] = + rewriter.create(loc, input, i).getResult(); + else + transposedTileSizes[index] = rewriter.getIndexAttr(inputShape[i]); + } + } + + // Pack the input tensor. 
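A worked example of the full-tile (tilingFactor <= 0) path above, assuming a 4-D NCHW input packed along its channel dimension, i.e. targetIndices = {1} (illustrative, not from the patch):

// tensor.pack with inner_dims_pos = {1} and a full-size tile of C elements turns
//   [N, C, H, W]  into  [N, 1, H, W, C]          (the outer C dim becomes 1)
// getUntiledPackReassociationMap({1}, /*rank=*/4) == {{0}, {1, 2}, {3}, {4}},
// so the tensor.collapse_shape emitted below folds the unit dim away:
//   [N, 1, H, W, C]  ->  [N, H, W, C]
// matching getPackedVector({N, C, H, W}, {1}) == {N, H, W, C}.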
+ auto empty = tensor::PackOp::createDestinationTensor( + rewriter, loc, input, transposedTileSizes, targetIndices, + SmallVector{}); + auto packedInput = rewriter.create( + loc, input, empty, targetIndices, transposedTileSizes, + /*padding=*/std::nullopt, SmallVector{}); + + SmallVector mapResults(inputMap.getResults()); + AffineMap transposedMap; + + Value packedOperand = packedInput; + // Collapse the unit dims created by tensor.pack. + if (tilingFactor <= 0) { + auto reassociationMap = + getUntiledPackReassociationMap(targetIndices, inType.getRank()); + auto transposedInputShape = + getPackedVector(llvm::to_vector(inputShape), targetIndices); + packedOperand = + rewriter + .create( + loc, RankedTensorType::get(transposedInputShape, elementType), + packedOperand, reassociationMap) + .getResult(); + transposedMap = + AffineMap::get(inputMap.getNumDims(), inputMap.getNumSymbols(), + getPackedVector(mapResults, targetIndices), + input.getContext()); + } else { + for (auto innerDim : targetIndices) { + mapResults.push_back(rewriter.getAffineDimExpr( + innerDimToDomainDim[inputMap.getDimPosition(innerDim)])); + } + transposedMap = AffineMap::get( + inputMap.getNumDims() + innerDimToDomainDim.size(), + inputMap.getNumSymbols(), mapResults, input.getContext()); + } + + return std::make_tuple(packedOperand, packedInput, transposedMap); +} + +// Transpose the given tensor based on the given transpose indices. Marks the +// created transpose based on the propagation direction. +static Value createTransposeAsTensorUnPack(PatternRewriter &rewriter, + Location loc, Value output, + tensor::PackOp packOp, + int tilingFactor) { + Value packedOutput = output; + if (tilingFactor <= 0) { + RankedTensorType outType = output.getType().cast(); + auto elementType = outType.getElementType(); + auto outputShape(outType.getShape()); + int64_t rank = outType.getRank(); + TransposeIndices targetIndices(packOp.getInnerDimsPos()); + + int startDim = + *std::min_element(targetIndices.begin(), targetIndices.end()); + SmallVector expandedOutputShape; + for (int i = 0, e = startDim; i < e; i++) + expandedOutputShape.push_back(outputShape[i]); + for (int i = 0, e = targetIndices.size(); i < e; i++) + expandedOutputShape.push_back(1); + for (int i = startDim, e = rank; i < e; i++) + expandedOutputShape.push_back(outputShape[i]); + + auto reassociationMap = getUntiledPackReassociationMap(targetIndices, rank); + packedOutput = + rewriter + .create( + loc, RankedTensorType::get(expandedOutputShape, elementType), + output, reassociationMap) + .getResult(); + } + + Value empty = tensor::UnPackOp::createDestinationTensor( + rewriter, loc, packedOutput, packOp.getMixedTiles(), + packOp.getInnerDimsPos(), packOp.getOuterDimsPerm()); + + auto unpackedOutput = rewriter.create( + loc, packedOutput, empty, packOp.getInnerDimsPos(), + packOp.getMixedTiles(), packOp.getOuterDimsPerm()); + unpackedOutput->setAttr("__unpack__", rewriter.getUnitAttr()); + return unpackedOutput.getResult(); +} + +static TransposeIndices collectChannelTransposeIndices( + AffineMap map, SmallVector> transposeDimTargets) { + SmallVector channelIndices(transposeDimTargets.size()); + for (auto [index, result] : llvm::enumerate(map.getResults())) { + if (llvm::isa(result)) { + for (auto [channelVec, dimCategory] : + llvm::zip_equal(channelIndices, transposeDimTargets)) { + if (llvm::is_contained(dimCategory, + llvm::cast(result).getPosition())) { + channelVec.push_back(index); + break; + } + } + } + } + + TransposeIndices indices; + for (auto channelVec : 
channelIndices) indices.append(channelVec); + return indices; +} + +static LogicalResult transposeConvLikeLinalgOp( + PatternRewriter &rewriter, linalg::LinalgOp convOp, int tilingFactor, + ConvBuilderFn convBuilder = defaultConvBuilderFn) { + Location loc = convOp.getLoc(); + + linalg::ConvolutionDimensions convDims; + auto errString = getMatchConvolutionMessage( + linalg::detail::isConvolutionInterfaceImpl(convOp, &convDims)); + if (!errString.empty()) return failure(); + + if (convDims.inputChannel.size() > 1) return failure(); + + if (convDims.outputChannel.size() > 1) return failure(); + + // TODO: Support depthwise convolutions + if (!convDims.depth.empty()) return failure(); + + Value input = convOp->getOperand(0); + Value filter = convOp->getOperand(1); + Value output = convOp->getOperand(2); + + auto inputMap = convOp.getIndexingMapsArray()[0]; + auto filterMap = convOp.getIndexingMapsArray()[1]; + auto outputMap = convOp.getIndexingMapsArray()[2]; + + auto inputIndices = + collectChannelTransposeIndices(inputMap, {convDims.inputChannel}); + auto filterIndices = collectChannelTransposeIndices( + filterMap, {convDims.inputChannel, convDims.outputChannel}); + auto outputIndices = + collectChannelTransposeIndices(outputMap, {convDims.outputChannel}); + + // Don't transpose if there's no change to the op. + if (isInnerIdentityIndices(inputIndices, inputMap.getNumResults()) && + isInnerIdentityIndices(filterIndices, filterMap.getNumResults()) && + isInnerIdentityIndices(outputIndices, outputMap.getNumResults())) + return failure(); + + int nDims = outputMap.getNumDims(); + llvm::DenseMap innerDimsToDomainDims; + for (auto [index, dim] : llvm::enumerate(convDims.inputChannel)) { + innerDimsToDomainDims[dim] = nDims + index; + } + for (auto [index, dim] : llvm::enumerate(convDims.outputChannel)) { + innerDimsToDomainDims[dim] = nDims + index + convDims.inputChannel.size(); + } + + auto [transposedInput, inputPack, transposedInputMap] = + createTransposeAsTensorPack(rewriter, loc, input, inputMap, inputIndices, + tilingFactor, innerDimsToDomainDims); + auto [transposedFilter, filterPack, transposedFilterMap] = + createTransposeAsTensorPack(rewriter, loc, filter, filterMap, + filterIndices, tilingFactor, + innerDimsToDomainDims); + auto [transposedOutput, outputPack, transposedOutputMap] = + createTransposeAsTensorPack(rewriter, loc, output, outputMap, + outputIndices, tilingFactor, + innerDimsToDomainDims); + + // Don't transpose if there's no change to the op. 
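For orientation, the channel indices collected above for a plain linalg.conv_2d_nchw_fchw, assuming its standard indexing maps (d4 = input channel, d1 = output channel; illustrative, not from the patch):

// input  map results (d0, d4, <spatial expr>, <spatial expr>): the spatial
//   results are not bare dim exprs and are skipped, so inputIndices == {1}.
// filter map results (d1, d4, d5, d6): result #1 is the input channel and
//   result #0 the output channel, so filterIndices == {1, 0}.
// output map results (d0, d1, d2, d3): result #1 is the output channel, so
//   outputIndices == {1}.
// Packing those positions innermost is the channels-last conversion:
//   NCHW input -> NHWC(-ish), FCHW filter -> HWCF, NFHW output -> NHWF.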
+ if (transposedInputMap == inputMap && transposedFilterMap == filterMap && + transposedOutputMap == outputMap) + return failure(); + + Value convDest = transposedOutput; + if (auto fillOp = output.getDefiningOp()) { + if (outputPack) { + auto outputDest = outputPack->getDest().getDefiningOp(); + auto elementType = outputDest.getType().getElementType(); + + auto dimToTileMapping = outputPack->getDimAndTileMapping(); + SmallVector mixedSizes = outputDest.getMixedSizes(); + SmallVector packedSizes; + for (auto [index, size] : llvm::enumerate(mixedSizes)) + if (!dimToTileMapping.count(index) || tilingFactor > 0) + packedSizes.push_back(size); + + auto emptyOp = + rewriter.create(loc, packedSizes, elementType); + + convDest = rewriter + .create(loc, fillOp.getInputs(), + emptyOp.getResult()) + .result(); + } + } + + SmallVector newDimOrder; + SmallVector newIteratorTypes; + if (tilingFactor <= 0) { + newDimOrder.append(convDims.batch); + newDimOrder.append(convDims.outputImage); + newDimOrder.append(convDims.outputChannel); + newDimOrder.append(convDims.filterLoop); + newDimOrder.append(convDims.inputChannel); + } else { + newIteratorTypes.append(convDims.inputChannel.size(), + utils::IteratorType::reduction); + newIteratorTypes.append(convDims.outputChannel.size(), + utils::IteratorType::parallel); + } + + Value transposedConvResult = + convBuilder(rewriter, loc, convOp, transposedInput, transposedFilter, + convDest, transposedInputMap, transposedFilterMap, + transposedOutputMap, newDimOrder, newIteratorTypes); + + Value returnToNCHW = transposedConvResult; + if (outputPack) { + returnToNCHW = createTransposeAsTensorUnPack( + rewriter, loc, transposedConvResult, *outputPack, tilingFactor); + } + + rewriter.replaceOp(convOp, returnToNCHW); + return success(); +} + +namespace { + +//===================================================================== +// Convolution packing patterns +//===================================================================== + +struct ConvertLinalgConvNchwFchw : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + ConvertLinalgConvNchwFchw(MLIRContext *context, PatternBenefit benefit = 2) + : OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp, + PatternRewriter &rewriter) const override { + return transposeConvLikeLinalgOp( + rewriter, convOp, /*tilingFactor=*/-1, + namedConvBuilderFn); + } +}; + +struct ConvertLinalgPoolingNchwMax + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + ConvertLinalgPoolingNchwMax(MLIRContext *context, PatternBenefit benefit = 2) + : OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(linalg::PoolingNchwMaxOp poolOp, + PatternRewriter &rewriter) const override { + return transposeConvLikeLinalgOp( + rewriter, poolOp, /*tilingFactor=*/-1, + namedConvBuilderFn); + } +}; + +struct ConvertLinalgPoolingNchwSum + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + ConvertLinalgPoolingNchwSum(MLIRContext *context, PatternBenefit benefit = 2) + : OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(linalg::PoolingNchwSumOp poolOp, + PatternRewriter &rewriter) const override { + return transposeConvLikeLinalgOp( + rewriter, poolOp, /*tilingFactor=*/-1, + namedConvBuilderFn); + } +}; + +struct ConvertLinalgConvOp : OpInterfaceRewritePattern { + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + ConvertLinalgConvOp(MLIRContext *context, int tile, + PatternBenefit benefit = 1) + : 
OpInterfaceRewritePattern(context, benefit), + tilingFactor(tile) {} + + LogicalResult matchAndRewrite(linalg::LinalgOp op, + PatternRewriter &rewriter) const override { + if (op->hasAttr(fullTileTransposeMarker)) + return transposeConvLikeLinalgOp(rewriter, op, 0); + return transposeConvLikeLinalgOp(rewriter, op, tilingFactor); + } + + private: + int tilingFactor; +}; + +//===================================================================== +// Propagation patterns +//===================================================================== + +class BubbleUpPackThroughPadOp final : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PackOp packOp, + PatternRewriter &rewriter) const override { + auto padOp = packOp.getSource().getDefiningOp(); + if (!padOp) return failure(); + + if (!padOp.getResult().hasOneUse()) return failure(); + + // TODO: Enable padding. + if (packOp.getPaddingValue()) return failure(); + + // TODO: Enable outer dims perm. + if (!packOp.getOuterDimsPerm().empty()) return failure(); + + // We want to move the pack not the insert_slice. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(padOp); + + Location loc = padOp->getLoc(); + auto mixedTiles = packOp.getMixedTiles(); + auto innerDimsPos = packOp.getInnerDimsPos(); + auto outerDimsPerm = packOp.getOuterDimsPerm(); + if (!packOp.getDest().getDefiningOp()) return failure(); + + // Bail out if one of the padded dimension is a tiled one. + llvm::SmallBitVector paddedDims = padOp.getPaddedDims(); + llvm::SmallBitVector innerDims(paddedDims.size()); + for (int64_t dim : innerDimsPos) innerDims.flip(dim); + if (paddedDims.anyCommon(innerDims)) return failure(); + + Value paddingVal = padOp.getConstantPaddingValue(); + if (!paddingVal) return failure(); + + auto empty = tensor::PackOp::createDestinationTensor( + rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos, + outerDimsPerm); + Value packedSource = rewriter.create( + loc, padOp.getSource(), empty, innerDimsPos, mixedTiles, + /*padding=*/std::nullopt, outerDimsPerm); + + // If we have `outer_dims_perms` we need to adjust the padded dimensions. + SmallVector lowPad = padOp.getMixedLowPad(); + SmallVector highPad = padOp.getMixedHighPad(); + if (!outerDimsPerm.empty()) { + applyPermutationToVector(lowPad, outerDimsPerm); + applyPermutationToVector(highPad, outerDimsPerm); + } + // Add zero padding for the point loops. + size_t pointLoopsSize = innerDimsPos.size(); + lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); + highPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); + + auto newPadOp = rewriter.create( + loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal, + padOp.getNofold()); + rewriter.replaceOp(packOp, newPadOp.getResult()); + return success(); + } +}; + +class BubbleUpPackThroughTensorInsertSlice final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PackOp packOp, + PatternRewriter &rewriter) const override { + auto insertSliceOp = + packOp.getSource().getDefiningOp(); + if (!insertSliceOp) return failure(); + + if (!insertSliceOp.getResult().hasOneUse()) return failure(); + + // TODO: Enable rank reduced slice. + if (insertSliceOp.getSourceType().getRank() != + insertSliceOp.getDestType().getRank()) + return failure(); + + // TODO: Enable padding. + if (packOp.getPaddingValue()) return failure(); + + // TODO: Enable outer dims perm. 
+ if (!packOp.getOuterDimsPerm().empty()) return failure(); + + // We want to move the pack not the insert_slice. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(insertSliceOp); + + Location loc = insertSliceOp->getLoc(); + auto mixedTiles = packOp.getMixedTiles(); + auto innerDimsPos = packOp.getInnerDimsPos(); + auto outerDimsPerm = packOp.getOuterDimsPerm(); + Value packOpDest = packOp.getDest(); + if (!packOpDest.hasOneUse()) return failure(); + if (auto emptyOp = packOpDest.getDefiningOp()) { + packOpDest = tensor::PackOp::createDestinationTensor( + rewriter, loc, insertSliceOp.getDest(), mixedTiles, innerDimsPos, + outerDimsPerm); + } else { + DominanceInfo dom(insertSliceOp); + if (!dom.properlyDominates(packOpDest, insertSliceOp)) return failure(); + } + + SmallVector mixedSliceTiles(packOp.getMixedTiles()); + + SmallVector mixedOffsets(insertSliceOp.getMixedOffsets()); + SmallVector mixedSizes(insertSliceOp.getMixedSizes()); + SmallVector mixedStrides(insertSliceOp.getMixedStrides()); + + for (auto [index, dimPos, mixedTileSize] : + llvm::zip_equal(llvm::seq(0, innerDimsPos.size()), + innerDimsPos, mixedTiles)) { + if (!getConstantIntValue(mixedStrides[dimPos])) return failure(); + + std::optional constTileSize = getConstantIntValue(mixedTileSize); + if (!constTileSize) return failure(); + + std::optional constOffset = + getConstantIntValue(mixedOffsets[dimPos]); + if (!constOffset) return failure(); + + std::optional constSize = + getConstantIntValue(mixedSizes[dimPos]); + if (!constOffset) return failure(); + + int64_t tileSize = *constTileSize; + int64_t offset = *constOffset; + int64_t size = *constSize; + + if ((size % tileSize != 0 || offset % tileSize != 0) && + (offset / tileSize > (size + offset) / tileSize)) + return failure(); + mixedSliceTiles[index] = + rewriter.getI64IntegerAttr(std::min(size, tileSize)); + mixedOffsets[dimPos] = rewriter.getI64IntegerAttr(offset / tileSize); + mixedSizes[dimPos] = + rewriter.getI64IntegerAttr(std::max(size / tileSize, 1)); + + mixedOffsets.push_back(rewriter.getI64IntegerAttr(offset % tileSize)); + mixedSizes.push_back( + rewriter.getI64IntegerAttr(std::min(size, tileSize))); + mixedStrides.push_back(rewriter.getI64IntegerAttr(1)); + } + + Value newDest = packOpDest; + if (!insertSliceOp.getDest().getDefiningOp()) { + newDest = rewriter.create( + loc, insertSliceOp.getDest(), packOpDest, innerDimsPos, mixedTiles, + /*padding=*/std::nullopt, outerDimsPerm); + } + + auto empty = tensor::PackOp::createDestinationTensor( + rewriter, loc, insertSliceOp.getSource(), mixedSliceTiles, innerDimsPos, + outerDimsPerm); + Value packedSlice = rewriter.create( + loc, insertSliceOp.getSource(), empty, innerDimsPos, mixedSliceTiles, + /*padding=*/std::nullopt, outerDimsPerm); + + rewriter.replaceOpWithNewOp( + packOp, packedSlice, newDest, mixedOffsets, mixedSizes, mixedStrides); + return success(); + } +}; + +//===================================================================== +// Generalization and folding patterns +//===================================================================== + +template +class GeneralizeUntiledPackOrUnPackOp final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(PackOrUnPackOpTy op, + PatternRewriter &rewriter) const override { + if (!op.getMixedTiles().empty()) return failure(); + TransposeIndices perm(op.getOuterDimsPerm()); + if (std::is_same::value) + perm = invertIndices(perm); + rewriter.replaceOpWithNewOp(op, 
op.getSource(), + op.getDest(), perm); + return success(); + } +}; + +static SmallVector getTilingReassociationMap( + int64_t rank, llvm::DenseMap innerDims) { + SmallVector map; + int64_t nTiled = 0; + for (int64_t i = 0, e = rank; i < e; i++) { + if (innerDims.count(i)) { + map.push_back({i + nTiled++, i + nTiled}); + continue; + } + map.push_back({i + nTiled}); + } + return map; +} + +class GeneralizeUnPermutedPackOp final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PackOp packOp, + PatternRewriter &rewriter) const override { + if (!packOp.getOuterDimsPerm().empty()) return failure(); + if (packOp.getPaddingValue()) return failure(); + + RankedTensorType srcType = + packOp.getSource().getType().cast(); + int64_t rank = srcType.getRank(); + auto innerDimsPos = packOp.getInnerDimsPos(); + llvm::DenseMap innerDims; + for (auto [index, innerDim] : llvm::enumerate(innerDimsPos)) + innerDims[innerDim] = index; + + llvm::DenseMap innerDimsToExpandedDims; + TransposeIndices perm; + int64_t nTiled = 0; + for (int i = 0, e = rank; i < e; i++) { + perm.push_back(i + nTiled); + if (innerDims.count(i)) innerDimsToExpandedDims[i] = i + ++nTiled; + } + for (auto i : innerDimsPos) perm.push_back(innerDimsToExpandedDims[i]); + + RankedTensorType destType = + packOp.getDest().getType().cast(); + SmallVector destShape(destType.getShape()); + applyPermutationToVector(destShape, invertPermutationVector(perm)); + + auto expand = rewriter.create( + packOp.getLoc(), + RankedTensorType::get(destShape, destType.getElementType()), + packOp.getSource(), getTilingReassociationMap(rank, innerDims)); + + rewriter.replaceOpWithNewOp(packOp, expand, + packOp.getDest(), perm); + return success(); + } +}; + +class GeneralizeUnPermutedUnPackOp final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp, + PatternRewriter &rewriter) const override { + if (!unpackOp.getOuterDimsPerm().empty()) return failure(); + + if (!unpackOp.getDest().getDefiningOp()) return failure(); + + RankedTensorType destType = + unpackOp.getDest().getType().cast(); + int64_t rank = destType.getRank(); + auto innerDimsPos = unpackOp.getInnerDimsPos(); + llvm::DenseMap innerDims; + for (auto [index, innerDim] : llvm::enumerate(innerDimsPos)) + innerDims[innerDim] = index; + + TransposeIndices perm; + for (int i = 0, e = rank; i < e; i++) { + perm.push_back(i); + if (innerDims.count(i)) perm.push_back(rank + innerDims[i]); + } + + Location loc = unpackOp.getLoc(); + SmallVector mixedSizes = + tensor::getMixedSizes(rewriter, loc, unpackOp.getSource()); + applyPermutationToVector(mixedSizes, perm); + auto elType = getElementTypeOrSelf(unpackOp.getDest()); + + auto emptyOp = rewriter.create(loc, mixedSizes, elType); + + Value transpose = rewriter + .create( + loc, unpackOp.getSource(), emptyOp, perm) + ->getResult(0); + + rewriter.replaceOpWithNewOp( + unpackOp, destType, transpose, + getTilingReassociationMap(rank, innerDims)); + return success(); + } +}; + +class GeneralizeLinalgTransposeOp final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::TransposeOp op, + PatternRewriter &rewriter) const override { + auto linalgOp = cast(*op); + auto transpose = + rewriter + .create( + op.getLoc(), op.getResult().getType(), op.getInput(), + op.getInit(), linalgOp.getIndexingMapsArray(), + 
linalgOp.getIteratorTypesArray(), + [](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args[0]); + }) + .getResult(0); + rewriter.replaceOp(op, transpose); + return success(); + } +}; + +class FoldCancellingUnPackPackOps final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp, + PatternRewriter &rewriter) const override { + return tensor::UnPackOp::canonicalize(unpackOp, rewriter); + } +}; + +class FoldCancellingPackUnPackOps final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PackOp packOp, + PatternRewriter &rewriter) const override { + return tensor::PackOp::canonicalize(packOp, rewriter); + } +}; + +struct ConvertConvToChannelsLastPass + : public ConvertConvToChannelsLastBase { + public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + LogicalResult initializeOptions(StringRef options) override { + if (failed(Pass::initializeOptions(options))) { + return failure(); + } + tilingFactor = tileSize; + return success(); + } + + void runOnOperation() override { + auto op = getOperation(); + MLIRContext *context = &getContext(); + + { + RewritePatternSet patterns(context); + if (tilingFactor < 0) { + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + } + patterns.insert(context, tilingFactor); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { + return signalPassFailure(); + } + } + + { + RewritePatternSet patterns(context); + linalg::populateDataLayoutPropagationPatterns( + patterns, [](Operation *op) { return true; }); + patterns.insert(context); + patterns.insert(context); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { + return signalPassFailure(); + } + } + + { + RewritePatternSet patterns(context); + patterns.insert(context); + patterns.insert(context); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { + return signalPassFailure(); + } + } + + { + RewritePatternSet patterns(context); + patterns.add(context); + patterns.insert(context); + patterns.insert>(context); + patterns.insert>( + context); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { + return signalPassFailure(); + } + } + + { + RewritePatternSet patterns(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { + return signalPassFailure(); + } + } + } + + private: + int64_t tilingFactor; +}; + +} // namespace + +std::unique_ptr createConvertConvToChannelsLastPass() { + return std::make_unique(); +} + +} // namespace mlir::iree_compiler::Preprocessing diff --git a/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeConvolutions.cpp b/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeConvolutions.cpp new file mode 100644 index 000000000000..f3a56404375e --- /dev/null +++ b/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeConvolutions.cpp @@ -0,0 +1,64 @@ +// Copyright 2020 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Preprocessing/Common/PassDetail.h" +#include "iree/compiler/Preprocessing/Common/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir::iree_compiler::Preprocessing { + +namespace { + +template +class GeneralizeTargetNamedOp final : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LinalgOpType linalgOp, + PatternRewriter &rewriter) const override { + FailureOr genericOp = + linalg::generalizeNamedOp(rewriter, linalgOp); + if (failed(genericOp)) return failure(); + return success(); + } +}; + +struct GeneralizeConvolutionsPass + : GeneralizeConvolutionsBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(&getContext()); + patterns.insert>(context); + patterns.insert>(context); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr createGeneralizeConvolutionsPass() { + return std::make_unique(); +} + +} // namespace mlir::iree_compiler::Preprocessing diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h index fcf5ddf86499..b6158e0caab2 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h @@ -10,6 +10,7 @@ #include #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -19,11 +20,21 @@ namespace mlir::iree_compiler::Preprocessing { /// using im2col tranformation. std::unique_ptr createConvertConv2DToImg2ColPass(); +// Creates a pass to convert linalg NCHW Convolutions to NHWC. +std::unique_ptr> +createConvertConvNchwToNhwcPass(); + /// Moves the body of the entire function into a single dispatch. std::unique_ptr> createMakeSingleDispatchForFunctionPass(); -/// A pass to pad linalg ops to the next integer multiple of `paddingSize`. +// A pass to generalize all conv-like ops. +std::unique_ptr createGeneralizeConvolutionsPass(); + +// Creates a pass to convert convolutions to channels last and propagate. +std::unique_ptr createConvertConvToChannelsLastPass(); + +// A pass to pad linalg ops to the next integer multiple of `paddingSize`. 
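[Editor note: illustrative only, not part of this diff.] The two new entry points declared above are ordinary preprocessing passes. A downstream pipeline would wire them up the usual way; the function name and the ordering below are a sketch, not a recommendation. On the command line they surface under the names registered in Passes.td further down (e.g. --iree-preprocessing-generalize-convolutions and --iree-preprocessing-convert-conv-to-channels-last with its tile-size option, following standard MLIR pass-option syntax).

    // Sketch only: adding the new preprocessing passes to an existing pipeline.
    #include "iree/compiler/Preprocessing/Common/Passes.h"
    #include "mlir/Pass/PassManager.h"

    void addExampleConvPreprocessing(mlir::OpPassManager &pm) {
      // Rewrites NCHW-style convolutions to channels-last and propagates the layout.
      pm.addPass(
          mlir::iree_compiler::Preprocessing::createConvertConvToChannelsLastPass());
      // Generalizes any remaining named convolution ops into linalg.generic.
      pm.addPass(
          mlir::iree_compiler::Preprocessing::createGeneralizeConvolutionsPass());
    }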
std::unique_ptr createPadLinalgOpsToIntegerMultiplePass(); //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td index ae67e754cfcc..0f523fdcd3c9 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td @@ -15,12 +15,37 @@ def ConvertConv2DToImg2Col : let constructor = "mlir::iree_compiler::Preprocessing::createConvertConv2DToImg2ColPass()"; } +def ConvertConvNchwToNhwc : + InterfacePass<"iree-flow-convert-conv-nchw-to-nhwc", "mlir::FunctionOpInterface"> { + let summary = "Convert linalg NCHW Convolutions to NHWC"; + let constructor = + "mlir::iree_compiler::Preprocessing::createConvertConvNchwToNhwcPass()"; +} + def MakeSingleDispatchForFunction : Pass<"iree-preprocessing-make-single-dispatch-for-function", "func::FuncOp"> { let summary = "Convert entire function into a single dispatch"; let constructor = "mlir::iree_compiler::Preprocessing::createMakeSingleDispatchForFunctionPass()"; } +def GeneralizeConvolutions : + Pass<"iree-preprocessing-generalize-convolutions", ""> { + let summary = "Generalize all convolution ops"; + let constructor = "mlir::iree_compiler::Preprocessing::createGeneralizeConvolutionsPass()"; +} + +def ConvertConvToChannelsLast : + Pass<"iree-preprocessing-convert-conv-to-channels-last", ""> { + let summary = "Convert linalg convolutions to channels last."; + let constructor = + "mlir::iree_compiler::Preprocessing::createConvertConvToChannelsLastPass()"; + let options = [ + Option<"tileSize", "tile-size", "int", + /*default=*/"0", + "Specify the tiling factor">, + ]; +} + def PadLinalgOps : Pass<"iree-preprocessing-pad-linalg-ops", ""> { let summary = "Pad linalg ops to the next integer multiple of paddingSize."; diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel index d11313bc5a98..044ffc329326 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel @@ -16,6 +16,7 @@ iree_lit_test_suite( name = "lit", srcs = enforce_glob( [ + "conv2d_nchw_to_nhwc.mlir", "conv2d_to_img2col.mlir", "make_single_dispatch_for_function.mlir", "pad_linalg_ops.mlir", diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt index 19425cb4944d..c47bc7b0ea22 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt @@ -14,6 +14,7 @@ iree_lit_test_suite( NAME lit SRCS + "conv2d_nchw_to_nhwc.mlir" "conv2d_to_img2col.mlir" "make_single_dispatch_for_function.mlir" "pad_linalg_ops.mlir" diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/conv2d_nchw_to_nhwc.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/conv2d_nchw_to_nhwc.mlir new file mode 100644 index 000000000000..b7ab1af35a69 --- /dev/null +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/conv2d_nchw_to_nhwc.mlir @@ -0,0 +1,40 @@ +// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(func.func(iree-flow-convert-conv-nchw-to-nhwc))" %s | FileCheck %s + +func.func @batch_conv(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<8x16x14x14xf32>) -> 
tensor<8x16x14x14xf32> { + %0 = linalg.conv_2d_nchw_fchw + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>) + outs(%arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> + return %0 : tensor<8x16x14x14xf32> +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1, d0)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)> +// CHECK: @batch_conv +// CHECK: %[[INPUT:.+]]: tensor<8x4x16x16xf32> +// CHECK: %[[FILTER:.+]]: tensor<16x4x3x3xf32> +// CHECK: %[[OUTPUT:.+]]: tensor<8x16x14x14xf32> +// CHECK: %[[INIT_INPUT_TRANSPOSE:.+]] = tensor.empty() {__nchw_to_nhwc_init__} : tensor<8x16x16x4xf32> +// CHECK: %[[TRANSPOSED_INPUT:.+]] = linalg.generic +// CHECK-SAME: #[[MAP0]] +// CHECK-SAME: #[[MAP1]] +// CHECK-SAME: ins(%[[INPUT]] : tensor<8x4x16x16xf32>) outs(%[[INIT_INPUT_TRANSPOSE]] : tensor<8x16x16x4xf32>) +// CHECK: %[[INIT_FILTER_TRANSPOSE:.+]] = tensor.empty() {__nchw_to_nhwc_init__} : tensor<3x3x4x16xf32> +// CHECK: %[[TRANSPOSED_FILTER:.+]] = linalg.generic +// CHECK-SAME: #[[MAP0]] +// CHECK-SAME: #[[MAP2]] +// CHECK-SAME: ins(%[[FILTER]] : tensor<16x4x3x3xf32>) outs(%[[INIT_FILTER_TRANSPOSE]] : tensor<3x3x4x16xf32>) +// CHECK: %[[INIT_OUTPUT_TRANSPOSE:.+]] = tensor.empty() {__nchw_to_nhwc_init__} : tensor<8x14x14x16xf32> +// CHECK: %[[TRANSPOSED_OUTPUT:.+]] = linalg.generic +// CHECK-SAME: #[[MAP0]] +// CHECK-SAME: #[[MAP1]] +// CHECK-SAME: ins(%[[OUTPUT]] : tensor<8x16x14x14xf32>) outs(%[[INIT_OUTPUT_TRANSPOSE]] : tensor<8x14x14x16xf32>) +// CHECK: %[[TRANSPOSED_RESULT:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[TRANSPOSED_INPUT]], %[[TRANSPOSED_FILTER]] : tensor<8x16x16x4xf32>, tensor<3x3x4x16xf32>) outs(%[[TRANSPOSED_OUTPUT]] : tensor<8x14x14x16xf32>) -> tensor<8x14x14x16xf32> +// CHECK: %[[INIT_RESULT:.+]] = tensor.empty() {__nchw_to_nhwc_init__} : tensor<8x16x14x14xf32> +// CHECK: %[[RESULT:.+]] = linalg.generic +// CHECK-SAME: #[[MAP0]] +// CHECK-SAME: #[[MAP3]] +// CHECK-SAME: ins(%[[TRANSPOSED_RESULT]] : tensor<8x14x14x16xf32>) outs(%[[INIT_RESULT]] : tensor<8x16x14x14xf32>) +// CHECK: return %[[RESULT]] : tensor<8x16x14x14xf32> diff --git a/experimental/rocm/CMakeLists.txt b/experimental/rocm/CMakeLists.txt index 7a48dabe8fda..9cbc6ee14633 100644 --- a/experimental/rocm/CMakeLists.txt +++ b/experimental/rocm/CMakeLists.txt @@ -61,6 +61,8 @@ iree_cc_library( "pipeline_layout.h" "status_util.c" "status_util.h" + "stream_command_buffer.c" + "stream_command_buffer.h" "tracing.c" "tracing.h" INCLUDES @@ -75,8 +77,11 @@ iree_cc_library( iree::base::internal::flatcc::parsing iree::base::internal::synchronization iree::hal + iree::hal::utils::collective_batch + iree::hal::utils::deferred_command_buffer iree::hal::utils::file_transfer iree::hal::utils::memory_file + iree::hal::utils::resource_set iree::hal::utils::semaphore_base iree::schemas::rocm_executable_def_c_fbs COPTS diff --git a/experimental/rocm/api.h b/experimental/rocm/api.h index 68fa1913bf2f..7949ac407afa 100644 --- a/experimental/rocm/api.h +++ b/experimental/rocm/api.h @@ -16,6 +16,46 @@ extern "C" { #endif // __cplusplus +//===----------------------------------------------------------------------===// +// iree_hal_rocm_device_t 
+//===----------------------------------------------------------------------===// + +// Defines how command buffers are recorded and executed. +typedef enum iree_hal_rocm_command_buffer_mode_e { + // Command buffers are recorded into ROCM null stream. + IREE_HAL_ROCM_COMMAND_BUFFER_MODE_DIRECT = 0, + // Command buffers are directly issued against ROCM stream. + IREE_HAL_ROCM_COMMAND_BUFFER_MODE_STREAM = 1, +} iree_hal_rocm_command_buffer_mode_t; + +// Parameters configuring an iree_hal_rocm_device_t. +// Must be initialized with iree_hal_rocm_device_params_initialize prior to use. +typedef struct iree_hal_rocm_device_params_t { + + // Total size of each block in the device shared block pool. + // Larger sizes will lower overhead and ensure the heap isn't hit for + // transient allocations while also increasing memory consumption. + iree_host_size_t arena_block_size; + + // Specifies how command buffers are recorded and executed. + iree_hal_rocm_command_buffer_mode_t command_buffer_mode; + + // Enables tracing of command buffers when IREE tracing is enabled. + // May take advantage of additional extensions for more accurate timing or + // hardware-specific performance counters. + // + // NOTE: tracing has a non-trivial overhead and will skew the timing of + // submissions and introduce false barriers between dispatches. Use this to + // identify slow dispatches and refine from there; be wary of whole-program + // tracing with this enabled. + bool stream_tracing; + +} iree_hal_rocm_device_params_t; + +// Initializes |out_params| to default values. +IREE_API_EXPORT void iree_hal_rocm_device_params_initialize( + iree_hal_rocm_device_params_t* out_params); + //===----------------------------------------------------------------------===// // iree_hal_rocm_driver_t //===----------------------------------------------------------------------===// @@ -35,6 +75,7 @@ IREE_API_EXPORT void iree_hal_rocm_driver_options_initialize( // |out_driver| must be released by the caller (see |iree_hal_driver_release|). 
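[Editor note: illustrative only, not part of this diff.] Putting the pieces above together, a host application initializes the parameter struct, picks the command-buffer mode and tracing behavior, and passes the params alongside the driver options when creating the driver. The identifier string and helper calls such as iree_make_cstring_view / iree_allocator_system are ordinary IREE base APIs used here for illustration; error handling is elided.

    iree_hal_rocm_device_params_t params;
    iree_hal_rocm_device_params_initialize(&params);
    params.command_buffer_mode = IREE_HAL_ROCM_COMMAND_BUFFER_MODE_STREAM;
    params.stream_tracing = false;  // leave off unless analyzing dispatch timings

    iree_hal_rocm_driver_options_t options;
    iree_hal_rocm_driver_options_initialize(&options);

    iree_hal_driver_t* driver = NULL;
    iree_status_t status = iree_hal_rocm_driver_create(
        iree_make_cstring_view("rocm"), &params, &options,
        iree_allocator_system(), &driver);
    // ... create devices from the driver, then iree_hal_driver_release(driver).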
IREE_API_EXPORT iree_status_t iree_hal_rocm_driver_create( iree_string_view_t identifier, + const iree_hal_rocm_device_params_t* default_params, const iree_hal_rocm_driver_options_t *options, iree_allocator_t host_allocator, iree_hal_driver_t **out_driver); diff --git a/experimental/rocm/dynamic_symbol_tables.h b/experimental/rocm/dynamic_symbol_tables.h index 785f0edc9ea3..b28acef5471b 100644 --- a/experimental/rocm/dynamic_symbol_tables.h +++ b/experimental/rocm/dynamic_symbol_tables.h @@ -25,16 +25,19 @@ RC_PFN_DECL(hipInit, unsigned int) RC_PFN_DECL(hipModuleLaunchKernel, hipFunction_t, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, hipStream_t, void **, void **) +RC_PFN_DECL(hipMemAdvise, const void *, size_t, int, int) RC_PFN_DECL(hipMemset, void *, int, size_t) RC_PFN_DECL(hipMemsetAsync, void *, int, size_t, hipStream_t) RC_PFN_DECL(hipMemsetD32Async, void *, int, size_t, hipStream_t) RC_PFN_DECL(hipMemsetD16Async, void *, short, size_t, hipStream_t) RC_PFN_DECL(hipMemsetD8Async, void *, char, size_t, hipStream_t) RC_PFN_DECL(hipMemcpy, void *, const void *, size_t, hipMemcpyKind) +RC_PFN_DECL(hipMemcpyHtoDAsync, hipDeviceptr_t, void *, size_t, hipStream_t) RC_PFN_DECL(hipMemcpyAsync, void *, const void *, size_t, hipMemcpyKind, hipStream_t) RC_PFN_DECL(hipMalloc, void **, size_t) RC_PFN_DECL(hipMallocManaged, hipDeviceptr_t *, size_t, unsigned int) +RC_PFN_DECL(hipMemPrefetchAsync, const void *, size_t, int, hipStream_t) RC_PFN_DECL(hipFree, void *) RC_PFN_DECL(hipHostFree, void *) RC_PFN_DECL(hipMemAllocHost, void **, size_t, unsigned int) diff --git a/experimental/rocm/registration/driver_module.c b/experimental/rocm/registration/driver_module.c index fcdadfe3c112..f1e180a91803 100644 --- a/experimental/rocm/registration/driver_module.c +++ b/experimental/rocm/registration/driver_module.c @@ -11,6 +11,19 @@ #include "experimental/rocm/api.h" #include "iree/base/api.h" +#include "iree/base/internal/flags.h" + +// Force using ROCM streams until we support command buffer caching to avoid the +// overhead of graph creation. 
+IREE_FLAG( + bool, rocm_use_streams, true, + "Use ROCM streams for executing command buffers (instead of graphs)."); + +IREE_FLAG( + bool, rocm_tracing, true, + "Enables tracing of stream events when Tracy instrumentation is enabled.\n" + "Severely impacts benchmark timings and should only be used when\n" + "analyzing dispatch timings."); static iree_status_t iree_hal_rocm_driver_factory_enumerate( void *self, iree_host_size_t *out_driver_info_count, @@ -36,10 +49,18 @@ static iree_status_t iree_hal_rocm_driver_factory_try_create( (int)driver_name.size, driver_name.data); } IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_rocm_device_params_t default_params; + iree_hal_rocm_device_params_initialize(&default_params); + if (FLAG_rocm_use_streams) { + default_params.command_buffer_mode = + IREE_HAL_ROCM_COMMAND_BUFFER_MODE_STREAM; + } + default_params.stream_tracing = FLAG_rocm_tracing; iree_hal_rocm_driver_options_t driver_options; iree_hal_rocm_driver_options_initialize(&driver_options); iree_status_t status = iree_hal_rocm_driver_create( - driver_name, &driver_options, host_allocator, out_driver); + driver_name, &default_params, &driver_options, host_allocator, out_driver); IREE_TRACE_ZONE_END(z0); return status; } diff --git a/experimental/rocm/rocm_allocator.c b/experimental/rocm/rocm_allocator.c index 3c63c71ec1bd..dbd0ea4b9936 100644 --- a/experimental/rocm/rocm_allocator.c +++ b/experimental/rocm/rocm_allocator.c @@ -15,8 +15,11 @@ typedef struct iree_hal_rocm_allocator_t { iree_hal_resource_t resource; - iree_hal_device_t* base_device; iree_hal_rocm_context_wrapper_t* context; + hipDevice_t device; + hipStream_t stream; + + bool supports_concurrent_managed_access; IREE_STATISTICS(iree_hal_allocator_statistics_t statistics;) } iree_hal_rocm_allocator_t; @@ -30,10 +33,30 @@ static iree_hal_rocm_allocator_t* iree_hal_rocm_allocator_cast( } iree_status_t iree_hal_rocm_allocator_create( - iree_hal_rocm_context_wrapper_t* context, + iree_hal_rocm_context_wrapper_t* context, hipDevice_t device, hipStream_t stream, iree_hal_allocator_t** out_allocator) { IREE_ASSERT_ARGUMENT(context); IREE_TRACE_ZONE_BEGIN(z0); + + // To support device-local + host-visible memory we need concurrent managed + // access indicating that the host and devices can concurrently access the + // device memory. If we don't have this feature then we fall back to forcing + // all device-local + host-visible memory into host-local + device-visible + // page-locked memory. The compiler tries to avoid this for high-traffic + // buffers except for readback staging buffers. + int supports_concurrent_managed_access = 0; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, ROCM_RESULT_TO_STATUS( + context->syms, + hipDeviceGetAttribute( + &supports_concurrent_managed_access, + hipDeviceAttributeConcurrentManagedAccess, device), + "hipDeviceGetAttribute")); + IREE_TRACE_ZONE_APPEND_TEXT( + z0, supports_concurrent_managed_access + ? 
"has CONCURRENT_MANAGED_ACCESS" + : "no CONCURRENT_MANAGED_ACCESS (expect slow accesses on " + "device-local + host-visible memory)"); iree_hal_rocm_allocator_t* allocator = NULL; iree_status_t status = iree_allocator_malloc( context->host_allocator, sizeof(*allocator), (void**)&allocator); @@ -41,6 +64,9 @@ iree_status_t iree_hal_rocm_allocator_create( iree_hal_resource_initialize(&iree_hal_rocm_allocator_vtable, &allocator->resource); allocator->context = context; + allocator->device = device; + allocator->stream = stream; + allocator->supports_concurrent_managed_access = supports_concurrent_managed_access !=0; *out_allocator = (iree_hal_allocator_t*)allocator; } @@ -87,24 +113,31 @@ static iree_status_t iree_hal_rocm_allocator_query_memory_heaps( iree_host_size_t capacity, iree_hal_allocator_memory_heap_t* IREE_RESTRICT heaps, iree_host_size_t* IREE_RESTRICT out_count) { - const iree_host_size_t count = 3; + iree_hal_rocm_allocator_t* allocator = + iree_hal_rocm_allocator_cast(base_allocator); + + // TODO(benvanik): check CU_DEVICE_ATTRIBUTE_INTEGRATED and return a unified + // set of heaps (likely still a cached and uncached, at minimum). + iree_host_size_t count = 3; + if (allocator->supports_concurrent_managed_access) { + ++count; // device-local | host-visible + } if (out_count) *out_count = count; if (capacity < count) { // NOTE: lightweight as this is hit in normal pre-sizing usage. return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); } - // NOTE: this is all a guess - someone who is familiar with rocm will want - // to refine this further. - // Don't think there's a query for these. // Max allocation size may be much smaller in certain memory types such as // page-locked memory and it'd be good to enforce that. const iree_device_size_t max_allocation_size = ~(iree_device_size_t)0; const iree_device_size_t min_alignment = 64; + int i = 0; + // Device-local memory (dispatch resources): - heaps[0] = (iree_hal_allocator_memory_heap_t){ + heaps[i++] = (iree_hal_allocator_memory_heap_t){ .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, .allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_DISPATCH, @@ -112,27 +145,46 @@ static iree_status_t iree_hal_rocm_allocator_query_memory_heaps( .min_alignment = min_alignment, }; + if (allocator->supports_concurrent_managed_access) { + // Device-local managed memory with host mapping support: + heaps[i++] = (iree_hal_allocator_memory_heap_t){ + .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | + IREE_HAL_MEMORY_TYPE_HOST_COHERENT, + .allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_DISPATCH | + IREE_HAL_BUFFER_USAGE_MAPPING, + .max_allocation_size = max_allocation_size, + .min_alignment = min_alignment, + }; + } + // Write-combined page-locked host-local memory (upload): - heaps[1] = (iree_hal_allocator_memory_heap_t){ - .type = - IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_COHERENT, - .allowed_usage = - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, + heaps[i++] = (iree_hal_allocator_memory_heap_t){ + .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE | + IREE_HAL_MEMORY_TYPE_HOST_LOCAL | + IREE_HAL_MEMORY_TYPE_HOST_COHERENT, + .allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_DISPATCH | + IREE_HAL_BUFFER_USAGE_MAPPING, .max_allocation_size = max_allocation_size, .min_alignment = min_alignment, }; // Cached page-locked host-local memory (download): - heaps[2] = (iree_hal_allocator_memory_heap_t){ - .type = 
IREE_HAL_MEMORY_TYPE_HOST_LOCAL | + heaps[i++] = (iree_hal_allocator_memory_heap_t){ + .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE | + IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_COHERENT | IREE_HAL_MEMORY_TYPE_HOST_CACHED, - .allowed_usage = - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, + .allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_DISPATCH | + IREE_HAL_BUFFER_USAGE_MAPPING, .max_allocation_size = max_allocation_size, .min_alignment = min_alignment, }; + IREE_ASSERT(i == count); return iree_ok_status(); } @@ -141,22 +193,46 @@ iree_hal_rocm_allocator_query_buffer_compatibility( iree_hal_allocator_t* IREE_RESTRICT base_allocator, iree_hal_buffer_params_t* IREE_RESTRICT params, iree_device_size_t* IREE_RESTRICT allocation_size) { + iree_hal_rocm_allocator_t* allocator = + iree_hal_rocm_allocator_cast(base_allocator); + // All buffers can be allocated on the heap. iree_hal_buffer_compatibility_t compatibility = IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE; - if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) { - compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER; + // Buffers are importable in ROCM under most cases, though performance may + // vary wildly. We don't fully verify that the buffer parameters are + // self-consistent and just look at whether we can get a device pointer. + if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) { + compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_IMPORTABLE; } // Buffers can only be used on the queue if they are device visible. if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) { + if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) { + compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER; + } if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE)) { compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH; } } + // If concurrent managed access is not supported then make device-local + + // host-visible allocations fall back to host-local + device-visible + // page-locked memory. This will be significantly slower for the device to + // access but the compiler only uses this type for readback staging buffers + // and it's better to function than function fast. + if (!allocator->supports_concurrent_managed_access && + iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { + compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_LOW_PERFORMANCE; + params->type &= ~(IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE); + params->type |= + IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE; + } + // We are now optimal. params->type &= ~IREE_HAL_MEMORY_TYPE_OPTIMAL; @@ -209,6 +285,24 @@ static iree_status_t iree_hal_rocm_allocator_allocate_buffer( status = ROCM_RESULT_TO_STATUS( allocator->context->syms, hipMallocManaged(&device_ptr, allocation_size, hipMemAttachGlobal)); + if (iree_status_is_ok(status)) { + status = ROCM_RESULT_TO_STATUS( + allocator->context->syms, + hipMemAdvise(device_ptr, allocation_size, + hipMemAdviseSetPreferredLocation, allocator->device)); + status = ROCM_RESULT_TO_STATUS( + allocator->context->syms, + hipMemAdvise(device_ptr, allocation_size, + hipMemAdviseSetCoarseGrain, allocator->device)); + } + if (iree_status_is_ok(status) && + allocator->supports_concurrent_managed_access) { + // Prefetch the buffer on the GPU device. 
+ status = ROCM_RESULT_TO_STATUS( + allocator->context->syms, + hipMemPrefetchAsync(device_ptr, allocation_size, allocator->device, + allocator->stream)); + } host_ptr = (void*)device_ptr; } else { // Device only. diff --git a/experimental/rocm/rocm_allocator.h b/experimental/rocm/rocm_allocator.h index a2a89eab2cdd..c735e830b013 100644 --- a/experimental/rocm/rocm_allocator.h +++ b/experimental/rocm/rocm_allocator.h @@ -19,6 +19,8 @@ extern "C" { // Create a ROCM allocator. iree_status_t iree_hal_rocm_allocator_create( iree_hal_rocm_context_wrapper_t* context, + hipDevice_t device, + hipStream_t stream, iree_hal_allocator_t** out_allocator); #ifdef __cplusplus diff --git a/experimental/rocm/rocm_device.c b/experimental/rocm/rocm_device.c index b76f165efa5c..24da6bd7da66 100644 --- a/experimental/rocm/rocm_device.c +++ b/experimental/rocm/rocm_device.c @@ -19,8 +19,10 @@ #include "experimental/rocm/rocm_allocator.h" #include "experimental/rocm/rocm_event.h" #include "experimental/rocm/status_util.h" +#include "experimental/rocm//stream_command_buffer.h" #include "experimental/rocm/tracing.h" #include "iree/base/internal/arena.h" +#include "iree/hal/utils/deferred_command_buffer.h" #include "iree/hal/utils/file_transfer.h" #include "iree/hal/utils/memory_file.h" @@ -40,6 +42,9 @@ typedef struct iree_hal_rocm_device_t { // to ensure the symbols remains valid. iree_hal_driver_t* driver; + // Parameters used to control device behavior. + iree_hal_rocm_device_params_t params; + hipDevice_t device; // TODO: support multiple streams. @@ -50,6 +55,10 @@ typedef struct iree_hal_rocm_device_t { // Optional provider used for creating/configuring collective channels. iree_hal_channel_provider_t* channel_provider; + + // Cache of the direct stream command buffer initialized when in stream mode. + // TODO: have one cached per stream once there are multiple streams. + iree_hal_command_buffer_t* stream_command_buffer; } iree_hal_rocm_device_t; static const iree_hal_device_vtable_t iree_hal_rocm_device_vtable; @@ -60,11 +69,21 @@ static iree_hal_rocm_device_t* iree_hal_rocm_device_cast( return (iree_hal_rocm_device_t*)base_value; } +IREE_API_EXPORT void iree_hal_rocm_device_params_initialize( + iree_hal_rocm_device_params_t* out_params) { + memset(out_params, 0, sizeof(*out_params)); + out_params->arena_block_size = 32*1024; + out_params->command_buffer_mode = IREE_HAL_ROCM_COMMAND_BUFFER_MODE_DIRECT; + out_params->stream_tracing = false; +} + static void iree_hal_rocm_device_destroy(iree_hal_device_t* base_device) { iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device); IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_command_buffer_release(device->stream_command_buffer); + // There should be no more buffers live that use the allocator. iree_hal_allocator_release(device->device_allocator); @@ -75,6 +94,8 @@ static void iree_hal_rocm_device_destroy(iree_hal_device_t* base_device) { ROCM_IGNORE_ERROR(device->context_wrapper.syms, hipStreamDestroy(device->stream)); + iree_arena_block_pool_deinitialize(&device->block_pool); + // Finally, destroy the device. 
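[Editor note: illustrative only, not part of this diff.] The allocator changes above boil down to one HIP recipe for device-local + host-visible requests: allocate managed memory, advise HIP to keep it resident on the GPU with coarse-grained coherence, and, when the device reports concurrent managed access, prefetch it onto the device stream. A standalone sketch against the public HIP runtime API follows; the function name is a placeholder and error handling is elided.

    #include <hip/hip_runtime.h>
    #include <stddef.h>

    // Returns a managed allocation preferred to live on |device|.
    static void* alloc_managed_device_local(int device, hipStream_t stream,
                                            size_t size) {
      int concurrent = 0;
      hipDeviceGetAttribute(&concurrent,
                            hipDeviceAttributeConcurrentManagedAccess, device);
      void* ptr = NULL;
      hipMallocManaged(&ptr, size, hipMemAttachGlobal);
      // Keep pages resident on the GPU and avoid fine-grained coherence traffic.
      hipMemAdvise(ptr, size, hipMemAdviseSetPreferredLocation, device);
      hipMemAdvise(ptr, size, hipMemAdviseSetCoarseGrain, device);
      if (concurrent) {
        // Migrate the pages to the device ahead of first use.
        hipMemPrefetchAsync(ptr, size, device, stream);
      }
      return ptr;
    }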
iree_hal_driver_release(device->driver); @@ -85,9 +106,9 @@ static void iree_hal_rocm_device_destroy(iree_hal_device_t* base_device) { static iree_status_t iree_hal_rocm_device_create_internal( iree_hal_driver_t* driver, iree_string_view_t identifier, - hipDevice_t rocm_device, hipStream_t stream, hipCtx_t context, - iree_hal_rocm_dynamic_symbols_t* syms, iree_allocator_t host_allocator, - iree_hal_device_t** out_device) { + const iree_hal_rocm_device_params_t* params, hipDevice_t rocm_device, + hipStream_t stream, hipCtx_t context, iree_hal_rocm_dynamic_symbols_t* syms, + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { iree_hal_rocm_device_t* device = NULL; iree_host_size_t total_size = sizeof(*device) + identifier.size; IREE_RETURN_IF_ERROR( @@ -99,20 +120,37 @@ static iree_status_t iree_hal_rocm_device_create_internal( uint8_t* buffer_ptr = (uint8_t*)device + sizeof(*device); buffer_ptr += iree_string_view_append_to_buffer( identifier, &device->identifier, (char*)buffer_ptr); + device->params = *params; device->device = rocm_device; device->stream = stream; device->context_wrapper.rocm_context = context; device->context_wrapper.rocm_device = rocm_device; device->context_wrapper.host_allocator = host_allocator; + iree_arena_block_pool_initialize(params->arena_block_size, host_allocator, + &device->block_pool); device->context_wrapper.syms = syms; // Enable tracing for the (currently only) stream - no-op if disabled. - iree_status_t status = iree_hal_rocm_tracing_context_allocate( + iree_status_t status = iree_ok_status(); + if (device->params.stream_tracing) { + status = iree_hal_rocm_tracing_context_allocate( &device->context_wrapper, device->identifier, stream, &device->block_pool, host_allocator, &device->tracing_context); + } if (iree_status_is_ok(status)) { status = iree_hal_rocm_allocator_create(&device->context_wrapper, + device->device, device->stream, &device->device_allocator); } + if (iree_status_is_ok(status) && + params->command_buffer_mode == IREE_HAL_ROCM_COMMAND_BUFFER_MODE_STREAM) { + status = iree_hal_rocm_stream_command_buffer_create( + (iree_hal_device_t*)device, &device->context_wrapper, + device->tracing_context, + IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION | + IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED, + IREE_HAL_COMMAND_CATEGORY_ANY, /*binding_capacity=*/0, device->stream, + &device->block_pool, &device->stream_command_buffer); + } if (iree_status_is_ok(status)) { *out_device = (iree_hal_device_t*)device; } else { @@ -123,10 +161,12 @@ static iree_status_t iree_hal_rocm_device_create_internal( iree_status_t iree_hal_rocm_device_create(iree_hal_driver_t* driver, iree_string_view_t identifier, + const iree_hal_rocm_device_params_t* params, iree_hal_rocm_dynamic_symbols_t* syms, hipDevice_t device, iree_allocator_t host_allocator, iree_hal_device_t** out_device) { + IREE_ASSERT_ARGUMENT(params); IREE_TRACE_ZONE_BEGIN(z0); hipCtx_t context; IREE_RETURN_AND_END_ZONE_IF_ERROR( @@ -140,8 +180,8 @@ iree_status_t iree_hal_rocm_device_create(iree_hal_driver_t* driver, syms, hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); if (iree_status_is_ok(status)) { - status = iree_hal_rocm_device_create_internal(driver, identifier, device, - stream, context, syms, + status = iree_hal_rocm_device_create_internal(driver, identifier, params, + device, stream, context, syms, host_allocator, out_device); } if (!iree_status_is_ok(status)) { @@ -228,10 +268,21 @@ static iree_status_t iree_hal_rocm_device_create_command_buffer( iree_hal_queue_affinity_t 
queue_affinity, iree_host_size_t binding_capacity, iree_hal_command_buffer_t** out_command_buffer) { iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); - return iree_hal_rocm_direct_command_buffer_create( - base_device, &device->context_wrapper, device->tracing_context, mode, - command_categories, queue_affinity, binding_capacity, &device->block_pool, - out_command_buffer); + switch (device->params.command_buffer_mode) { + case IREE_HAL_ROCM_COMMAND_BUFFER_MODE_DIRECT: + return iree_hal_rocm_direct_command_buffer_create( + base_device, &device->context_wrapper, device->tracing_context, mode, + command_categories, queue_affinity, binding_capacity, &device->block_pool, + out_command_buffer); + case IREE_HAL_ROCM_COMMAND_BUFFER_MODE_STREAM: + return iree_hal_deferred_command_buffer_create( + base_device, mode, command_categories, binding_capacity, + &device->block_pool, iree_hal_device_host_allocator(base_device), + out_command_buffer); + default: + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "invalid command buffer mode"); + } } static iree_status_t iree_hal_rocm_device_create_descriptor_set_layout( @@ -383,8 +434,28 @@ static iree_status_t iree_hal_rocm_device_queue_execute( // synchronizes after every submit. // TODO(raikonenfnu): currently run on default/null stream, when cmd buffer // stream work with device->stream, we'll change + for (iree_host_size_t i = 0; i < command_buffer_count; i++) { + iree_hal_command_buffer_t* command_buffer = command_buffers[i]; + if (iree_hal_rocm_stream_command_buffer_isa(command_buffer)) { + // Nothing to do for an inline command buffer; all the work has already + // been submitted. When we support semaphores we'll still need to signal + // their completion but do not have to worry about any waits: if there + // were waits we wouldn't have been able to execute inline! + } else if (iree_hal_rocm_direct_command_buffer_isa(command_buffer)) { + IREE_TRACE_ZONE_BEGIN_NAMED(z0, "hipStreamSynchronize"); + ROCM_RETURN_IF_ERROR(device->context_wrapper.syms, hipStreamSynchronize(0), + "hipStreamSynchronize"); + iree_hal_rocm_tracing_context_collect(device->tracing_context); + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); + } else { + IREE_RETURN_IF_ERROR(iree_hal_deferred_command_buffer_apply( + command_buffers[i], device->stream_command_buffer, + iree_hal_buffer_binding_table_empty())); + } + } IREE_TRACE_ZONE_BEGIN_NAMED(z0, "hipStreamSynchronize"); - ROCM_RETURN_IF_ERROR(device->context_wrapper.syms, hipStreamSynchronize(0), + ROCM_RETURN_IF_ERROR(device->context_wrapper.syms, hipStreamSynchronize(device->stream), "hipStreamSynchronize"); iree_hal_rocm_tracing_context_collect(device->tracing_context); IREE_TRACE_ZONE_END(z0); diff --git a/experimental/rocm/rocm_device.h b/experimental/rocm/rocm_device.h index 083f4c7cddb6..7abd4e67ce36 100644 --- a/experimental/rocm/rocm_device.h +++ b/experimental/rocm/rocm_device.h @@ -19,6 +19,7 @@ extern "C" { // Creates a device that owns and manages its own hipContext. 
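[Editor note: illustrative only, not part of this diff.] The submission logic above distinguishes three cases: stream command buffers have already executed inline, direct command buffers still synchronize the legacy null stream, and deferred command buffers are replayed onto the cached stream command buffer before the device stream is synchronized. The HIP-level model behind the stream path is plain stream-ordered execution, sketched standalone below with error handling elided.

    #include <hip/hip_runtime.h>

    int main() {
      hipStream_t stream;
      hipStreamCreateWithFlags(&stream, hipStreamNonBlocking);

      void* buffer = NULL;
      hipMalloc(&buffer, 4096);

      // Work is enqueued eagerly and the call returns immediately...
      hipMemsetAsync(buffer, 0, 4096, stream);
      // ...and the submit path only has to wait on this stream (not the null
      // stream) for the queued work to complete.
      hipStreamSynchronize(stream);

      hipFree(buffer);
      hipStreamDestroy(stream);
      return 0;
    }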
iree_status_t iree_hal_rocm_device_create(iree_hal_driver_t* driver, iree_string_view_t identifier, + const iree_hal_rocm_device_params_t* params, iree_hal_rocm_dynamic_symbols_t* syms, hipDevice_t device, iree_allocator_t host_allocator, diff --git a/experimental/rocm/rocm_driver.c b/experimental/rocm/rocm_driver.c index bcec506e2f86..5b67fdc2f3a0 100644 --- a/experimental/rocm/rocm_driver.c +++ b/experimental/rocm/rocm_driver.c @@ -21,6 +21,7 @@ typedef struct iree_hal_rocm_driver_t { // We allow overriding so that multiple ROCM versions can be exposed in the // same process. iree_string_view_t identifier; + iree_hal_rocm_device_params_t default_params; int default_device_index; // ROCM symbols. iree_hal_rocm_dynamic_symbols_t syms; @@ -49,6 +50,7 @@ IREE_API_EXPORT void iree_hal_rocm_driver_options_initialize( static iree_status_t iree_hal_rocm_driver_create_internal( iree_string_view_t identifier, + const iree_hal_rocm_device_params_t* default_params, const iree_hal_rocm_driver_options_t* options, iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { iree_hal_rocm_driver_t* driver = NULL; @@ -60,6 +62,8 @@ static iree_status_t iree_hal_rocm_driver_create_internal( iree_string_view_append_to_buffer( identifier, &driver->identifier, (char*)driver + total_size - identifier.size); + memcpy(&driver->default_params, default_params, + sizeof(driver->default_params)); driver->default_device_index = options->default_device_index; iree_status_t status = iree_hal_rocm_dynamic_symbols_initialize(host_allocator, &driver->syms); @@ -84,14 +88,16 @@ static void iree_hal_rocm_driver_destroy(iree_hal_driver_t* base_driver) { IREE_API_EXPORT iree_status_t iree_hal_rocm_driver_create( iree_string_view_t identifier, + const iree_hal_rocm_device_params_t* default_params, const iree_hal_rocm_driver_options_t* options, iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { + IREE_ASSERT_ARGUMENT(default_params); IREE_ASSERT_ARGUMENT(options); IREE_ASSERT_ARGUMENT(out_driver); IREE_TRACE_ZONE_BEGIN(z0); iree_status_t status = iree_hal_rocm_driver_create_internal( - identifier, options, host_allocator, out_driver); + identifier, default_params, options, host_allocator, out_driver); IREE_TRACE_ZONE_END(z0); return status; @@ -286,7 +292,7 @@ static iree_status_t iree_hal_rocm_driver_create_device_by_id( // Attempt to create the device. iree_status_t status = - iree_hal_rocm_device_create(base_driver, device_name, &driver->syms, + iree_hal_rocm_device_create(base_driver, device_name, &driver->default_params, &driver->syms, device, host_allocator, out_device); IREE_TRACE_ZONE_END(z0); diff --git a/experimental/rocm/stream_command_buffer.c b/experimental/rocm/stream_command_buffer.c new file mode 100644 index 000000000000..f6d98df63275 --- /dev/null +++ b/experimental/rocm/stream_command_buffer.c @@ -0,0 +1,562 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "experimental/rocm/stream_command_buffer.h" + +#include "experimental/rocm/rocm_buffer.h" +#include "experimental/rocm/rocm_event.h" +#include "experimental/rocm/native_executable.h" +#include "experimental/rocm/pipeline_layout.h" +#include "experimental/rocm/status_util.h" +#include "iree/hal/utils/collective_batch.h" +#include "iree/hal/utils/resource_set.h" + +#define IREE_HAL_ROCM_MAX_BINDING_COUNT 64 +// Kernel arguments contains binding and push constants. +#define IREE_HAL_ROCM_MAX_KERNEL_ARG 128 + +typedef struct { + iree_hal_command_buffer_t base; + iree_hal_rocm_context_wrapper_t* context; + iree_hal_rocm_tracing_context_t* tracing_context; + hipStream_t stream; + + // Maintains a reference to all resources used within the command buffer. + // Reset on each begin. + iree_hal_resource_set_t* resource_set; + + // Staging arena used for host->device transfers. + // Used for when we need ROCM to be able to reference memory as it performs + // asynchronous operations. + iree_arena_allocator_t arena; + + // Iteratively constructed batch of collective operations. + iree_hal_collective_batch_t collective_batch; + + int32_t push_constant[IREE_HAL_ROCM_MAX_PUSH_CONSTANT_COUNT]; + + // Keep track of the current set of kernel arguments. + void* current_descriptor[IREE_HAL_ROCM_MAX_KERNEL_ARG]; + hipDeviceptr_t* device_ptrs[IREE_HAL_ROCM_MAX_KERNEL_ARG]; +} iree_hal_rocm_stream_command_buffer_t; + +static const iree_hal_command_buffer_vtable_t + iree_hal_rocm_stream_command_buffer_vtable; + +static iree_hal_rocm_stream_command_buffer_t* +iree_hal_rocm_stream_command_buffer_cast( + iree_hal_command_buffer_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_stream_command_buffer_vtable); + return (iree_hal_rocm_stream_command_buffer_t*)base_value; +} + +iree_status_t iree_hal_rocm_stream_command_buffer_create( + iree_hal_device_t* device, iree_hal_rocm_context_wrapper_t* context, + iree_hal_rocm_tracing_context_t* tracing_context, + iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_host_size_t binding_capacity, hipStream_t stream, + iree_arena_block_pool_t* block_pool, + iree_hal_command_buffer_t** out_command_buffer) { + IREE_ASSERT_ARGUMENT(device); + IREE_ASSERT_ARGUMENT(context); + IREE_ASSERT_ARGUMENT(out_command_buffer); + *out_command_buffer = NULL; + if (binding_capacity > 0) { + // TODO(#10144): support indirect command buffers with binding tables. 
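[Editor note: illustrative only, not part of this diff.] The current_descriptor / device_ptrs pair above exists because hipModuleLaunchKernel consumes an array of pointers to argument storage: each current_descriptor[i] entry permanently points at the device_ptrs[i] slot, so push_descriptor_set writes device pointers into those slots, dispatch patches the push constants in, and the same kernelParams array can be handed to every launch. A minimal sketch of that double indirection; the names below are placeholders.

    // Sketch only; |function|, |stream| and the launch geometry are placeholders.
    void* kernel_args[3];
    hipDeviceptr_t binding0 = NULL;   // written by push_descriptor_set
    hipDeviceptr_t binding1 = NULL;
    int32_t push_constant0 = 0;       // patched in by dispatch from push_constants
    kernel_args[0] = &binding0;
    kernel_args[1] = &binding1;
    kernel_args[2] = &push_constant0;
    // hipModuleLaunchKernel(function, grid_x, grid_y, grid_z, block_x, block_y,
    //                       block_z, /*sharedMemBytes=*/0, stream, kernel_args,
    //                       /*extra=*/NULL);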
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "indirect command buffers not yet implemented"); + } + + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_rocm_stream_command_buffer_t* command_buffer = NULL; + iree_status_t status = + iree_allocator_malloc(context->host_allocator, sizeof(*command_buffer), + (void**)&command_buffer); + if (iree_status_is_ok(status)) { + iree_hal_command_buffer_initialize( + device, mode, command_categories, IREE_HAL_QUEUE_AFFINITY_ANY, + binding_capacity, &iree_hal_rocm_stream_command_buffer_vtable, + &command_buffer->base); + command_buffer->context = context; + command_buffer->tracing_context = tracing_context; + command_buffer->stream = stream; + iree_arena_initialize(block_pool, &command_buffer->arena); + for (size_t i = 0; i < IREE_HAL_ROCM_MAX_KERNEL_ARG; i++) { + command_buffer->current_descriptor[i] = &command_buffer->device_ptrs[i]; + } + + status = iree_hal_resource_set_allocate(block_pool, + &command_buffer->resource_set); + } + if (iree_status_is_ok(status)) { + iree_hal_collective_batch_initialize(&command_buffer->arena, + command_buffer->resource_set, + &command_buffer->collective_batch); + } + + *out_command_buffer = &command_buffer->base; + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_rocm_stream_command_buffer_destroy( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_collective_batch_deinitialize(&command_buffer->collective_batch); + iree_hal_resource_set_free(command_buffer->resource_set); + iree_arena_deinitialize(&command_buffer->arena); + iree_allocator_free(command_buffer->context->host_allocator, command_buffer); + + IREE_TRACE_ZONE_END(z0); +} + +bool iree_hal_rocm_stream_command_buffer_isa( + iree_hal_command_buffer_t* command_buffer) { + return iree_hal_resource_is(&command_buffer->resource, + &iree_hal_rocm_stream_command_buffer_vtable); +} + +// Flushes any pending batched collective operations. +// Must be called before any other non-collective nodes are added to the graph +// or a barrier is encountered. +static iree_status_t iree_hal_rocm_stream_command_buffer_flush_collectives( + iree_hal_rocm_stream_command_buffer_t* command_buffer) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "Collectives not implemented on ROCM"); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_begin( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + (void)command_buffer; + + IREE_ROCM_TRACE_ZONE_BEGIN_EXTERNAL( + command_buffer->tracing_context, command_buffer->stream, + /*file_name=*/NULL, 0, + /*line=*/0, /*func_name=*/NULL, 0, "iree_hal_rocm_stream_command_buffer", + strlen("iree_hal_rocm_stream_command_buffer")); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_end( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + + // Reset the arena as there should be nothing using it now that we've + // dispatched all our operations inline. 
+ // NOTE: the resource set may contain resources we need to drop as we don't + // need to keep them live any longer than it takes to schedule the + // operations. In a real command buffer we would be this stream command + // buffer is strictly used to perform inline execution/replay of + // deferred command buffers that are retaining the resources already. + // NOTE: reseting the arena invalidates the collective batch. + iree_arena_reset(&command_buffer->arena); + iree_hal_collective_batch_deinitialize(&command_buffer->collective_batch); + iree_hal_resource_set_free(command_buffer->resource_set); + IREE_RETURN_IF_ERROR(iree_hal_resource_set_allocate( + command_buffer->arena.block_pool, &command_buffer->resource_set)); + iree_hal_collective_batch_initialize(&command_buffer->arena, + command_buffer->resource_set, + &command_buffer->collective_batch); + + IREE_ROCM_TRACE_ZONE_END(command_buffer->tracing_context, + command_buffer->stream); + + return iree_ok_status(); +} + +static void iree_hal_rocm_stream_command_buffer_begin_debug_group( + iree_hal_command_buffer_t* base_command_buffer, iree_string_view_t label, + iree_hal_label_color_t label_color, + const iree_hal_label_location_t* location) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + (void)command_buffer; + + IREE_ROCM_TRACE_ZONE_BEGIN_EXTERNAL( + command_buffer->tracing_context, command_buffer->stream, + location ? location->file.data : NULL, location ? location->file.size : 0, + location ? location->line : 0, /*func_name=*/NULL, 0, label.data, + label.size); + + // TODO: pass along to CUPTI if available. +} + +static void iree_hal_rocm_stream_command_buffer_end_debug_group( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + (void)command_buffer; + + // TODO: pass along to CUPTI if available. 
+ + IREE_ROCM_TRACE_ZONE_END(command_buffer->tracing_context, + command_buffer->stream); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_execution_barrier( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_hal_execution_barrier_flags_t flags, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { +// iree_hal_rocm_stream_command_buffer_t* command_buffer = +// iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + // TODO(raikonen): implement ROCM barrier + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_signal_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { +// iree_hal_rocm_stream_command_buffer_t* command_buffer = +// iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + // TODO(raikonen): implement ROCM barrier + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_reset_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { +// iree_hal_rocm_stream_command_buffer_t* command_buffer = +// iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + // TODO(raikonen): implement ROCM barrier + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_wait_events( + iree_hal_command_buffer_t* base_command_buffer, + iree_host_size_t event_count, const iree_hal_event_t** events, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { +// iree_hal_rocm_stream_command_buffer_t* command_buffer = +// iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + // TODO(raikonen): implement ROCM barrier + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_discard_buffer( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* buffer) { + // We could mark the memory as invalidated so that if managed ROCM does not + // try to copy it back to the host. 
+ return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_fill_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, const void* pattern, + iree_host_size_t pattern_length) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + + hipDeviceptr_t target_device_buffer = iree_hal_rocm_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(target_buffer)); + target_offset += iree_hal_buffer_byte_offset(target_buffer); + hipDeviceptr_t dst = + (hipDeviceptr_t)((uintptr_t)target_device_buffer + target_offset); + + size_t num_elements = length / pattern_length; + switch (pattern_length) { + case 4: { + ROCM_RETURN_IF_ERROR( + command_buffer->context->syms, + hipMemsetD32Async(dst, *(const uint32_t*)(pattern), num_elements, + command_buffer->stream), + "hipMemsetD32Async"); + break; + } + case 2: { + ROCM_RETURN_IF_ERROR( + command_buffer->context->syms, + hipMemsetD16Async(dst, *(const uint16_t*)(pattern), num_elements, + command_buffer->stream), + "hipMemsetD16Async"); + break; + } + case 1: { + ROCM_RETURN_IF_ERROR( + command_buffer->context->syms, + hipMemsetD8Async(dst, *(const uint8_t*)(pattern), num_elements, + command_buffer->stream), + "hipMemsetD8Async"); + break; + } + default: + return iree_make_status(IREE_STATUS_INTERNAL, + "unsupported fill pattern length"); + } + + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_update_buffer( + iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer, + iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer, + iree_device_size_t target_offset, iree_device_size_t length) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + + // Allocate scratch space in the arena for the data and copy it in. + // The update buffer API requires that the command buffer capture the host + // memory at the time the method is called in case the caller wants to reuse + // the memory. Because ROCM memcpys are async if we didn't copy it's possible + // for the reused memory to change before the stream reaches the copy + // operation and get the wrong data. + const uint8_t* src = (const uint8_t*)source_buffer + source_offset; + if (command_buffer->arena.block_pool) { + uint8_t* storage = NULL; + IREE_RETURN_IF_ERROR( + iree_arena_allocate(&command_buffer->arena, length, (void**)&storage)); + memcpy(storage, src, length); + src = storage; + } + + // Issue the copy using the scratch memory as the source. 
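[Editor note: illustrative only, not part of this diff.] The staging copy above exists because hipMemcpyHtoDAsync only reads the source bytes when the stream reaches the copy; snapshotting them into arena-owned storage makes the command buffer safe against the caller reusing its memory right after update_buffer returns. A standalone sketch of the ordering involved, error handling elided:

    #include <hip/hip_runtime.h>
    #include <stdint.h>

    int main() {
      hipStream_t stream;
      hipStreamCreateWithFlags(&stream, hipStreamNonBlocking);

      void* device_dst = NULL;
      hipMalloc(&device_dst, sizeof(uint32_t));

      uint32_t staged = 42;  // stands in for the arena-owned snapshot
      hipMemcpyHtoDAsync(device_dst, &staged, sizeof(staged), stream);
      // Mutating |staged| here, before the synchronize below, could race with
      // the asynchronous copy; the arena snapshot prevents exactly that.
      hipStreamSynchronize(stream);

      hipFree(device_dst);
      hipStreamDestroy(stream);
      return 0;
    }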
+ hipDeviceptr_t target_device_buffer = iree_hal_rocm_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(target_buffer)); + hipDeviceptr_t dst = (hipDeviceptr_t)((uintptr_t)target_device_buffer + + iree_hal_buffer_byte_offset(target_buffer) + target_offset); + ROCM_RETURN_IF_ERROR( + command_buffer->context->syms, + hipMemcpyHtoDAsync(dst, (void*)src, length, command_buffer->stream), + "hipMemcpyHtoDAsync"); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_copy_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + + hipDeviceptr_t target_device_buffer = iree_hal_rocm_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(target_buffer)); + target_offset += iree_hal_buffer_byte_offset(target_buffer); + hipDeviceptr_t source_device_buffer = iree_hal_rocm_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(source_buffer)); + source_offset += iree_hal_buffer_byte_offset(source_buffer); + hipDeviceptr_t dst = (hipDeviceptr_t)((uintptr_t)target_device_buffer + target_offset); + hipDeviceptr_t src = (hipDeviceptr_t)((uintptr_t)source_device_buffer + source_offset); + ROCM_RETURN_IF_ERROR(command_buffer->context->syms, + hipMemcpyAsync(dst, src, length, hipMemcpyDeviceToDevice, command_buffer->stream), + "hipMemcpyAsync"); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_collective( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel, + iree_hal_collective_op_t op, uint32_t param, + iree_hal_buffer_binding_t send_binding, + iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + return iree_hal_collective_batch_append(&command_buffer->collective_batch, + channel, op, param, send_binding, + recv_binding, element_count); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_push_constants( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset, + const void* values, iree_host_size_t values_length) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + + iree_host_size_t constant_base_index = offset / sizeof(int32_t); + for (iree_host_size_t i = 0; i < values_length / sizeof(int32_t); i++) { + command_buffer->push_constant[i + constant_base_index] = + ((uint32_t*)values)[i]; + } + + return iree_ok_status(); +} + +// Tie together the binding index and its index in |bindings| array. +typedef struct { + uint32_t index; + uint32_t binding; +} iree_hal_rocm_binding_mapping_t; + +// Helper to sort the binding based on their binding index. +static int compare_binding_index(const void* a, const void* b) { + const iree_hal_rocm_binding_mapping_t buffer_a = + *(const iree_hal_rocm_binding_mapping_t*)a; + const iree_hal_rocm_binding_mapping_t buffer_b = + *(const iree_hal_rocm_binding_mapping_t*)b; + return buffer_a.binding < buffer_b.binding ? 
-1 : 1; +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_push_descriptor_set( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + + iree_host_size_t base_binding = + iree_hal_rocm_base_binding_index(pipeline_layout, set); + + // Convention with the compiler side. We map bindings to kernel argument. + // We compact the bindings to get a dense set of arguments and keep them order + // based on the binding index. + // Sort the binding based on the binding index and map the array index to the + // argument index. + iree_hal_rocm_binding_mapping_t binding_used[IREE_HAL_ROCM_MAX_BINDING_COUNT]; + for (iree_host_size_t i = 0; i < binding_count; i++) { + iree_hal_rocm_binding_mapping_t buffer = {i, bindings[i].binding}; + binding_used[i] = buffer; + } + // TODO: remove this sort - it's thankfully small (1-8 on average) but we + // should be able to avoid it like we do on the CPU side with a bitmap. + qsort(binding_used, binding_count, sizeof(iree_hal_rocm_binding_mapping_t), + compare_binding_index); + assert(binding_count < IREE_HAL_ROCM_MAX_BINDING_COUNT && + "binding count larger than the max expected."); + + for (iree_host_size_t i = 0; i < binding_count; i++) { + iree_hal_descriptor_set_binding_t binding = bindings[binding_used[i].index]; + hipDeviceptr_t device_ptr = + binding.buffer + ? (hipDeviceptr_t)((uintptr_t)iree_hal_rocm_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(binding.buffer)) + + iree_hal_buffer_byte_offset(binding.buffer) + binding.offset) + : 0; + *((hipDeviceptr_t*)command_buffer->current_descriptor[i + base_binding]) = + device_ptr; + } + + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_dispatch( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + + // Lookup kernel parameters used for side-channeling additional launch + // information from the compiler. + iree_hal_rocm_kernel_params_t kernel_params; + IREE_RETURN_IF_ERROR( + iree_hal_rocm_native_executable_entry_point_kernel_params( + executable, entry_point, &kernel_params)); + + IREE_ROCM_TRACE_ZONE_BEGIN_EXTERNAL( + command_buffer->tracing_context, command_buffer->stream, kernel_params.function_name.data, + kernel_params.function_name.size, + /*line=*/0, /*func_name=*/NULL, 0, kernel_params.function_name.data, + kernel_params.function_name.size); + + // Patch the push constants in the kernel arguments. 
+ iree_host_size_t num_constants = + iree_hal_rocm_pipeline_layout_num_constants(kernel_params.layout); + iree_host_size_t constant_base_index = + iree_hal_rocm_push_constant_index(kernel_params.layout); + for (iree_host_size_t i = 0; i < num_constants; i++) { + *((uint32_t*)command_buffer->current_descriptor[i + constant_base_index]) = + command_buffer->push_constant[i]; + } + + ROCM_RETURN_IF_ERROR( + command_buffer->context->syms, + hipModuleLaunchKernel(kernel_params.function, workgroup_x, workgroup_y, + workgroup_z, kernel_params.block_size[0], + kernel_params.block_size[1], kernel_params.block_size[2], + kernel_params.shared_memory_size, command_buffer->stream, + command_buffer->current_descriptor, NULL), + "hipModuleLaunchKernel"); + IREE_ROCM_TRACE_ZONE_END(command_buffer->tracing_context, + command_buffer->stream); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_dispatch_indirect( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_t* workgroups_buffer, + iree_device_size_t workgroups_offset) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "need rocm implementation of dispatch indirect"); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_execute_commands( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_command_buffer_t* base_commands, + iree_hal_buffer_binding_table_t binding_table) { + // TODO(#10144): support indirect command buffers with deferred command + // buffers or graphs. We likely just want to switch to graphs. + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "indirect command buffers not yet implemented"); +} + +static const iree_hal_command_buffer_vtable_t + iree_hal_rocm_stream_command_buffer_vtable = { + .destroy = iree_hal_rocm_stream_command_buffer_destroy, + .begin = iree_hal_rocm_stream_command_buffer_begin, + .end = iree_hal_rocm_stream_command_buffer_end, + .begin_debug_group = + iree_hal_rocm_stream_command_buffer_begin_debug_group, + .end_debug_group = iree_hal_rocm_stream_command_buffer_end_debug_group, + .execution_barrier = + iree_hal_rocm_stream_command_buffer_execution_barrier, + .signal_event = iree_hal_rocm_stream_command_buffer_signal_event, + .reset_event = iree_hal_rocm_stream_command_buffer_reset_event, + .wait_events = iree_hal_rocm_stream_command_buffer_wait_events, + .discard_buffer = iree_hal_rocm_stream_command_buffer_discard_buffer, + .fill_buffer = iree_hal_rocm_stream_command_buffer_fill_buffer, + .update_buffer = iree_hal_rocm_stream_command_buffer_update_buffer, + .copy_buffer = iree_hal_rocm_stream_command_buffer_copy_buffer, + .collective = iree_hal_rocm_stream_command_buffer_collective, + .push_constants = iree_hal_rocm_stream_command_buffer_push_constants, + .push_descriptor_set = + iree_hal_rocm_stream_command_buffer_push_descriptor_set, + .dispatch = iree_hal_rocm_stream_command_buffer_dispatch, + .dispatch_indirect = + iree_hal_rocm_stream_command_buffer_dispatch_indirect, + .execute_commands = + iree_hal_rocm_stream_command_buffer_execute_commands, +}; diff --git a/experimental/rocm/stream_command_buffer.h b/experimental/rocm/stream_command_buffer.h new file mode 100644 index 000000000000..691fa63809ff --- /dev/null +++ b/experimental/rocm/stream_command_buffer.h @@ -0,0 +1,49 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_ROCM_STREAM_COMMAND_BUFFER_H_
+#define IREE_HAL_DRIVERS_ROCM_STREAM_COMMAND_BUFFER_H_
+
+#include "iree/base/internal/arena.h"
+#include "iree/hal/api.h"
+#include "experimental/rocm/context_wrapper.h"
+#include "experimental/rocm/rocm_headers.h"
+#include "experimental/rocm/dynamic_symbols.h"
+#include "experimental/rocm/tracing.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Creates a ROCM stream command buffer that immediately issues commands
+// against the given |stream|. Access to |stream| must be synchronized by the
+// user.
+//
+// If |block_pool| is non-NULL then the stream command buffer will retain
+// copies of input data until reset. If NULL then the caller must ensure the
+// lifetime of input data outlives the command buffer.
+//
+// This command buffer is used to both replay deferred command buffers and
+// perform inline execution. When replaying, the scratch data required for
+// things like buffer updates is retained by the source deferred command
+// buffer and as such the |block_pool| can be NULL to avoid a double copy.
+iree_status_t iree_hal_rocm_stream_command_buffer_create(
+    iree_hal_device_t* device, iree_hal_rocm_context_wrapper_t* context,
+    iree_hal_rocm_tracing_context_t* tracing_context,
+    iree_hal_command_buffer_mode_t mode,
+    iree_hal_command_category_t command_categories,
+    iree_host_size_t binding_capacity, hipStream_t stream,
+    iree_arena_block_pool_t* block_pool,
+    iree_hal_command_buffer_t** out_command_buffer);
+
+// Returns true if |command_buffer| is a ROCM stream-based command buffer.
+bool iree_hal_rocm_stream_command_buffer_isa(
+    iree_hal_command_buffer_t* command_buffer);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_ROCM_STREAM_COMMAND_BUFFER_H_
diff --git a/experimental/split_mlir/iree_compiler_plugin_group.cmake b/experimental/split_mlir/iree_compiler_plugin_group.cmake
new file mode 100644
index 000000000000..5ec79cd63392
--- /dev/null
+++ b/experimental/split_mlir/iree_compiler_plugin_group.cmake
@@ -0,0 +1,7 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src split_mlir)
diff --git a/experimental/split_mlir/lit.cfg.py b/experimental/split_mlir/lit.cfg.py
new file mode 100644
index 000000000000..eba45391c230
--- /dev/null
+++ b/experimental/split_mlir/lit.cfg.py
@@ -0,0 +1,41 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Lint for undefined variables is disabled because `config` is not defined
+# inside this file; instead, `config` is injected by evaluating runlit.cfg.py
+# from runlit.site.cfg.py, which in turn is evaluated by lit.py.
+# pylint: disable=undefined-variable
+
+import os
+import tempfile
+
+import lit.formats
+
+config.name = "IREE"
+config.suffixes = [".mlir", ".txt"]
+config.test_format = lit.formats.ShTest(execute_external=True)
+
+# Forward all IREE environment variables, as well as some passthroughs.
+# Note: env vars are case-insensitive on Windows, so check matches carefully.
+# https://stackoverflow.com/q/7797269 +passthrough_env_vars = [ + # The Vulkan loader uses this + "VK_ICD_FILENAMES", + # WindowsLinkerTool uses these from vcvarsall + "VCTOOLSINSTALLDIR", + "UNIVERSALCRTSDKDIR", + "UCRTVERSION" +] +config.environment.update({ + k: v + for k, v in os.environ.items() + if k.startswith("IREE_") or k in passthrough_env_vars +}) + +# Use the most preferred temp directory. +config.test_exec_root = (os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") or + os.environ.get("TEST_TMPDIR") or + os.path.join(tempfile.gettempdir(), "lit")) diff --git a/experimental/split_mlir/src/CMakeLists.txt b/experimental/split_mlir/src/CMakeLists.txt new file mode 100644 index 000000000000..da48cd621651 --- /dev/null +++ b/experimental/split_mlir/src/CMakeLists.txt @@ -0,0 +1,20 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_cc_library( + NAME + defs + INCLUDES + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_BINARY_DIR} + PUBLIC +) + +# Configures all iree_cc_* targets to take this implicit dep, +# which provides common includes and copts for the tree. +set(IREE_IMPLICIT_DEFS_CC_DEPS iree::experimental::split_mlir::src::defs) + +iree_add_all_subdirs() diff --git a/experimental/split_mlir/src/iree/CMakeLists.txt b/experimental/split_mlir/src/iree/CMakeLists.txt new file mode 100644 index 000000000000..33551b576974 --- /dev/null +++ b/experimental/split_mlir/src/iree/CMakeLists.txt @@ -0,0 +1,7 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_add_all_subdirs() diff --git a/experimental/split_mlir/src/iree/split_mlir/CMakeLists.txt b/experimental/split_mlir/src/iree/split_mlir/CMakeLists.txt new file mode 100644 index 000000000000..d59a3df68730 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/CMakeLists.txt @@ -0,0 +1,80 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_add_all_subdirs() + +iree_tablegen_library( + NAME + PassesIncGen + TD_FILE + "Passes.td" + OUTS + --gen-pass-decls Passes.h.inc +) + +iree_cc_library( + NAME + split_mlir_lib + HDRS + "Passes.h" + "Passes.h.inc" + SRCS + "Passes.cpp" + DEPS + ::PassesIncGen + MLIRFuncDialect + MLIRIR + MLIRPass + PUBLIC +) + +iree_cc_library( + NAME + registration + SRCS + "PluginRegistration.cpp" + DEPS + ::split_mlir_lib + MLIRIR + MLIRPass + iree::compiler::PluginAPI + PUBLIC +) + +iree_compiler_register_plugin( + PLUGIN_ID + split_mlir + TARGET + ::registration +) + +iree_pyext_module( + NAME + PyExt + MODULE_NAME _split_mlir + SRCS + "OperationListImpl.h" + "SplitMlirPyExt.cpp" + DEPS + MLIRFuncDialect + MLIRIR + MLIRAsmParser + iree::compiler::Tools::init_passes_and_dialects +) + +iree_py_library( + NAME + split_mlir_py + SRCS + "__init__.py" + "_split_mlir.pyi" + "execution.py" + "iree_execution.py" + "types.py" + DEPS + MLIRPythonModules + ::PyExt +) diff --git a/experimental/split_mlir/src/iree/split_mlir/MarkBisectPassImpl.h b/experimental/split_mlir/src/iree/split_mlir/MarkBisectPassImpl.h new file mode 100644 index 000000000000..1224dcd4818a --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/MarkBisectPassImpl.h @@ -0,0 +1,87 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include "iree/split_mlir/Passes.h" +#include "llvm/ADT/SmallSet.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/Passes.h" + +namespace mlir { +namespace iree { +namespace split_mlir { + +#define GEN_PASS_DEF_MARKBISECT +#include "iree/split_mlir/Passes.h.inc" // IWYU pragma: export + +namespace { + +void markRangeFirst(Operation& op, OpBuilder& builder) { + op.setAttr("outline_range_first", builder.getUnitAttr()); +} + +void markRangeLast(Operation& op, OpBuilder& builder) { + op.setAttr("outline_range_last", builder.getUnitAttr()); +} + +struct MarkBisectPass : public impl::MarkBisectBase { + using MarkBisectBase::MarkBisectBase; + + LogicalResult initialize(MLIRContext* context) override { + functionsSet.insert(functions.begin(), functions.end()); + return LogicalResult::success(); + } + + void runOnOperation() override { + mlir::func::FuncOp funcOp = getOperation(); + if (!functionsSet.contains(funcOp.getSymName())) { + return; + } + if (funcOp.getBody().getBlocks().size() > 1) { + return signalPassFailure(); + } + Block& entryBlock = funcOp.getBody().front(); + if (entryBlock.getOperations().size() < 3) { + // Degenerate case. Needs at least 1 op for each half + the return op. + return; + } + size_t opsCount = entryBlock.getOperations().size(); + size_t cutOpIndex = (opsCount - 1) / 2; + OpBuilder builder(&getContext()); + // Ranges are inclusive, [first, last]. 
+ auto firstHalfLastOp = entryBlock.begin(); + std::advance(firstHalfLastOp, cutOpIndex - 1); + markRangeFirst(entryBlock.front(), builder); + markRangeLast(*firstHalfLastOp, builder); + auto secondHalfFirstOp = firstHalfLastOp; + std::advance(secondHalfFirstOp, 1); + markRangeFirst(*secondHalfFirstOp, builder); + auto secondHalfLastOp = entryBlock.end(); + // Take operation that is just before the return operation. + std::advance(secondHalfLastOp, -2); + markRangeLast(*secondHalfLastOp, builder); + } + + private: + llvm::SmallSet functionsSet; +}; + +} // namespace + +std::unique_ptr> createMarkBisectPass() { + return std::make_unique(); +} + +} // namespace split_mlir +} // namespace iree +} // namespace mlir diff --git a/experimental/split_mlir/src/iree/split_mlir/OperationListImpl.h b/experimental/split_mlir/src/iree/split_mlir/OperationListImpl.h new file mode 100644 index 000000000000..df2f4180a2d3 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/OperationListImpl.h @@ -0,0 +1,181 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "iree/compiler/Tools/init_dialects.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/SourceMgr.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace iree { +namespace split_mlir { + +using OpId = std::string; +using OpIndex = size_t; +using ResultIndex = size_t; +using ResultId = std::tuple; +using Arguments = std::vector; +using OperationList = std::vector>; + +ResultIndex getResultIndex(OpOperand& operand) { + OpResult opResult = operand.get().dyn_cast(); + if (opResult) { + return opResult.getResultNumber(); + } + + BlockArgument blockArgument = operand.get().dyn_cast(); + assert(blockArgument); + return blockArgument.getArgNumber(); +} + +FailureOr getDefiningOpIndex( + OpOperand& operand, Block& block, + const std::unordered_map& operationInBlockIndexMap) { + Value value = operand.get(); + if (value.isa()) { + return 0; + } + + OpResult opResult = value.dyn_cast(); + if (!opResult) { + operand.getOwner()->emitError( + Twine("Operand ") + std::to_string(operand.getOperandNumber()) + + "is neigher a block argument or a result of an operation"); + return failure(); + } + if (value.getDefiningOp()->getBlock() != &block) { + operand.getOwner()->emitError( + "Can't extract call graph for block that is not isolated from above."); + return failure(); + } + + auto it = operationInBlockIndexMap.find(value.getDefiningOp()); + assert(it != operationInBlockIndexMap.end()); + return it->second; +} + +std::string getOpId(Operation& op) { + func::CallOp callOp = dyn_cast(op); + if (callOp) { + return (Twine("call ") + callOp.getCallee()).str(); + } + + if (isa(op)) { + return "return"; + } + + return op.getName().getStringRef().str(); +} + +FailureOr extractOperationList(Block& block) { + OperationList res; + // Block arguments don't depend on anything. + res.emplace_back(); + // Index inside the block. 
+ std::unordered_map operationInBlockIndexMap; + + for (auto opIt : llvm::enumerate(block)) { + operationInBlockIndexMap.insert({&opIt.value(), opIt.index() + 1}); + OpId id = getOpId(opIt.value()); + Arguments arguments; + for (OpOperand& operand : opIt.value().getOpOperands()) { + FailureOr opIndex = + getDefiningOpIndex(operand, block, operationInBlockIndexMap); + FailureOr resultIndex = getResultIndex(operand); + if (failed(opIndex) || failed(resultIndex)) { + return failure(); + } + arguments.emplace_back(opIndex.value(), resultIndex.value()); + } + res.emplace_back(id, arguments); + } + + return res; +} + +FailureOr> loadMlir(const char* mlirFilePath, + MLIRContext& context) { + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(mlirFilePath); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return failure(); + } + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + return parseSourceFile(sourceMgr, &context); +} + +func::FuncOp findFunction(Operation* root, StringRef name) { + func::FuncOp res; + root->walk([&](func::FuncOp op) { + if (op.getSymName() == name) { + res = op; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return res; +} + +FailureOr extractOperationList(ModuleOp moduleOp, + StringRef functionName) { + func::FuncOp funcOp = findFunction(moduleOp.getOperation(), functionName); + Region* region = funcOp.getCallableRegion(); + if (!region) { + funcOp.emitError("No callable region found."); + return failure(); + } + if (region->getBlocks().size() != 1) { + funcOp.emitError("Blocks count must be exactly 1."); + return failure(); + } + return extractOperationList(region->front()); +} + +FailureOr extractOperationList(const char* mlirFilePath, + StringRef functionName, + MLIRContext& context) { + auto moduleOp = loadMlir(mlirFilePath, context); + if (failed(moduleOp)) { + return failure(); + } + + return extractOperationList(moduleOp->get(), functionName); +} + +std::unique_ptr makeMlirContext() { + mlir::DialectRegistry registry; + mlir::iree_compiler::registerAllDialects(registry); + auto context = std::make_unique(registry); + return context; +} + +FailureOr extractOperationList(const char* mlirFilePath, + StringRef functionName) { + auto context = makeMlirContext(); + return extractOperationList(mlirFilePath, functionName, *context); +} + +} // namespace split_mlir +} // namespace iree +} // namespace mlir diff --git a/experimental/split_mlir/src/iree/split_mlir/OutlineFunctionsPassImpl.h b/experimental/split_mlir/src/iree/split_mlir/OutlineFunctionsPassImpl.h new file mode 100644 index 000000000000..783520e7fdb9 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/OutlineFunctionsPassImpl.h @@ -0,0 +1,299 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <algorithm>
+#include <iterator>
+#include <optional>
+#include <string>
+#include <utility>
+
+#include "iree/split_mlir/Passes.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+namespace iree {
+namespace split_mlir {
+
+#define GEN_PASS_DEF_OUTLINEFUNCTIONS
+#include "iree/split_mlir/Passes.h.inc"  // IWYU pragma: export
+
+namespace {
+
+// Collect all operation ranges that are marked for outlining.
+// The beginning of a range is marked with the outline_range_first attribute.
+// The last operation of a range is marked with the outline_range_last
+// attribute.
+// Example:
+// %0 = arith.addi %arg0, %arg1 {outline_range_first} : i32
+// %1 = arith.addi %arg2, %arg3 : i32
+// %2 = arith.muli %arg3, %arg4 {outline_range_last} : i32
+// The outline range will consist of the 3 operations.
+LogicalResult getOutlineOpRanges(
+    Block& block, SmallVector<iterator_range<Block::iterator>, 4>& res) {
+  bool isInOutliningRange = false;
+  Block::iterator rangeBegin;
+  for (Block::iterator opIt = block.begin(); opIt != block.end(); ++opIt) {
+    if (opIt->hasAttr("outline_range_first")) {
+      if (isInOutliningRange) {
+        return LogicalResult::failure();
+      }
+      isInOutliningRange = true;
+      rangeBegin = opIt;
+    }
+
+    if (opIt->hasAttr("outline_range_last")) {
+      if (!isInOutliningRange) {
+        return LogicalResult::failure();
+      }
+      isInOutliningRange = false;
+      res.emplace_back(rangeBegin, std::next(opIt));
+    }
+  }
+  if (isInOutliningRange) {
+    // No matching closing marker outline_range_last.
+    return LogicalResult::failure();
+  }
+
+  return LogicalResult::success();
+}
+
+// Return all values that are operands of the given ops but are produced by
+// ops outside the range. Also return all values that are results of the given
+// ops and have uses outside the ops range.
+std::pair, SmallVector> +getOperandsAndResultsForIsolation(iterator_range opRange, + const SmallPtrSet& opsSet) { + SmallVector operands; + SmallVector results; + SmallPtrSet operandsSet; + SmallPtrSet resultsSet; + for (Operation& op : opRange) { + for (Value operand : op.getOperands()) { + if (!opsSet.contains(operand.getDefiningOp())) { + auto insertionResult = operandsSet.insert(operand); + if (insertionResult.second) { + operands.push_back(operand); + } + } + } + for (OpResult result : op.getResults()) { + for (OpOperand operand : result.getUsers()) { + if (!opsSet.contains(operand.getOwner())) { + auto insertionResult = resultsSet.insert(result); + if (insertionResult.second) { + results.push_back(result); + } + break; + } + } + } + } + return {operands, results}; +} + +template +void replaceValueUsesWithNewBlockArguments(ValueIt valuesBegin, + ValueIt valuesEnd, Block& block) { + for (ValueIt valIt = valuesBegin; valIt != valuesEnd; ++valIt) { + block.addArgument(valIt->getType(), valIt->getLoc()); + BlockArgument& blockArg = block.getArguments().back(); + valIt->replaceUsesWithIf(blockArg, [&block](OpOperand& operand) { + return operand.getOwner()->getBlock() == █ + }); + } +} + +void addBlockReturn(Block& block, ValueRange operands, OpBuilder& builder) { + func::ReturnOp returnOp = + builder.create(builder.getUnknownLoc(), operands); + block.push_back(returnOp); +} + +void moveOpsIntoBlock(iterator_range opRange, Block& block) { + // Put ops into another container because opRange will be invalidated during + // removal. + SmallVector ops; + std::transform(opRange.begin(), opRange.end(), std::back_inserter(ops), + [](Operation& op) { return &op; }); + for (Operation* op : ops) { + op->moveBefore(&block, block.end()); + } +} + +void moveBlock(Region& srcRegion, Region& destRegion, + Region::iterator srcBlockIt, Region::iterator destBlockIt) { + Block* block = srcRegion.getBlocks().remove(srcBlockIt); + destRegion.getBlocks().insert(destBlockIt, block); +} + +bool isAncestorOfBlock(Operation* op, Block* block) { + // Walk up the operation hierarchy and check each block. + while (op != nullptr) { + if (op->getBlock() == block) { + return true; + } + op = op->getParentOp(); + } + return false; +} + +template +void substititeUses(OriginalOpResultsIt originalBegin, + OriginalOpResultsIt originalEnd, NewOpResultsIt newBegin, + NewOpResultsIt newEnd, Block& excludedBlock) { + assert(std::distance(originalBegin, originalEnd) == + std::distance(newBegin, newEnd)); + auto newIt = newBegin; + for (auto originalIt = originalBegin; originalIt != originalEnd; + ++originalIt, ++newIt) { + originalIt->replaceUsesWithIf(*newIt, [&excludedBlock](OpOperand& operand) { + return !isAncestorOfBlock(operand.getOwner(), &excludedBlock); + }); + } +} + +// All operations in the range `opRange` are moved into a new function with name +// `name`. The resulting function is put inside `moduleOp` and is properly +// isolated from above. This does not insert a call to the new function in place +// of the moved operations. 
+func::FuncOp createFunctionFromOps(iterator_range opRange, + StringRef name, ModuleOp moduleOp, + SmallVector& rangeOperands, + SmallVector& rangeResults, + OpBuilder& builder) { + Region& region = *opRange.begin()->getParentRegion(); + Block& dstBlock = region.emplaceBlock(); + moveOpsIntoBlock(opRange, dstBlock); + replaceValueUsesWithNewBlockArguments(rangeOperands.begin(), + rangeOperands.end(), dstBlock); + addBlockReturn(dstBlock, + ArrayRef(rangeResults.begin(), rangeResults.end()), + builder); + func::FuncOp funcOp = builder.create( + builder.getUnknownLoc(), name, + FunctionType::get(builder.getContext(), dstBlock.getArgumentTypes(), + dstBlock.back().getOperandTypes())); + moduleOp.getBodyRegion().getBlocks().front().push_back(funcOp); + moveBlock(region, funcOp.getBody(), std::prev(region.end()), + funcOp.getBody().end()); + + return funcOp; +} + +void createCall(func::FuncOp funcOp, Block& block, Block::iterator pos, + SmallVector& rangeOperands, + SmallVector& rangeResults, OpBuilder& builder) { + func::CallOp callOp = builder.create( + builder.getUnknownLoc(), funcOp, + ArrayRef(rangeOperands.begin(), rangeOperands.end())); + block.getOperations().insert(pos, callOp); + substititeUses(rangeResults.begin(), rangeResults.end(), + callOp.getResults().begin(), callOp.getResults().end(), + funcOp.getBody().back()); +} + +std::optional outlineOpRange( + iterator_range opRange, StringRef name, ModuleOp moduleOp, + OpBuilder& builder) { + if (opRange.empty()) { + return std::nullopt; + } + + SmallPtrSet opsSet; + for (Operation& op : opRange) { + opsSet.insert(&op); + } + SmallVector rangeOperands; + SmallVector rangeResults; + std::tie(rangeOperands, rangeResults) = + getOperandsAndResultsForIsolation(opRange, opsSet); + Block& srcBlock = *opRange.begin()->getBlock(); + + func::FuncOp funcOp = createFunctionFromOps( + opRange, name, moduleOp, rangeOperands, rangeResults, builder); + createCall(funcOp, srcBlock, opRange.end(), rangeOperands, rangeResults, + builder); + + return funcOp; +} + +std::string getOutlinedFuncName(StringRef prefix, int blockIndex, + int outlineRangeIndex) { + return (Twine(prefix) + "_outline_" + Twine(blockIndex) + "_" + + Twine(outlineRangeIndex)) + .str(); +} + +void removeOutlineMarkers(iterator_range opRange) { + if (opRange.empty()) { + return; + } + opRange.begin()->removeAttr("outline_range_first"); + std::prev(opRange.end())->removeAttr("outline_range_last"); +} + +// Each marked operation range in `funcOp` is outlined into a new function. +// A call to the new function is inserted in place of the outlined operations. 
+LogicalResult outlineOpRanges(func::FuncOp funcOp, ModuleOp moduleOp, + OpBuilder& builder) { + Region& funcBody = funcOp.getFunctionBody(); + SmallVector, 4> outlineRanges; + for (auto blockIt : llvm::enumerate(funcBody.getBlocks())) { + outlineRanges.clear(); + if (failed(getOutlineOpRanges(blockIt.value(), outlineRanges))) { + return LogicalResult::failure(); + } + for (auto rangeIt : llvm::enumerate(outlineRanges)) { + removeOutlineMarkers(rangeIt.value()); + std::string name = getOutlinedFuncName(funcOp.getSymName(), + blockIt.index(), rangeIt.index()); + outlineOpRange(rangeIt.value(), name, moduleOp, builder); + } + } + + return LogicalResult::success(); +} + +struct OutlineFunctionsPass + : public impl::OutlineFunctionsBase { + using OutlineFunctionsBase::OutlineFunctionsBase; + + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + Block& moduleBlock = *moduleOp.getBody(); + OpBuilder builder(&getContext()); + // Get all functions since we are going to insert new ones + // that we don't want to iterate over. + SmallVector funcOps( + moduleBlock.getOps().begin(), + moduleBlock.getOps().end()); + for (func::FuncOp op : funcOps) { + if (failed(outlineOpRanges(op, moduleOp, builder))) { + return signalPassFailure(); + } + } + } +}; + +} // namespace + +std::unique_ptr> createOutlineFunctionsPass() { + return std::make_unique(); +} + +} // namespace split_mlir +} // namespace iree +} // namespace mlir diff --git a/experimental/split_mlir/src/iree/split_mlir/Passes.cpp b/experimental/split_mlir/src/iree/split_mlir/Passes.cpp new file mode 100644 index 000000000000..50fd44ad024d --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/Passes.cpp @@ -0,0 +1,8 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/split_mlir/MarkBisectPassImpl.h" +#include "iree/split_mlir/OutlineFunctionsPassImpl.h" diff --git a/experimental/split_mlir/src/iree/split_mlir/Passes.h b/experimental/split_mlir/src/iree/split_mlir/Passes.h new file mode 100644 index 000000000000..039d00727365 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/Passes.h @@ -0,0 +1,37 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_SPLIT_MLIR_TRANSFORM_PASSES_H_
+#define IREE_SPLIT_MLIR_TRANSFORM_PASSES_H_
+
+#include <memory>
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+class ModuleOp;
+namespace func {
+class FuncOp;
+}  // namespace func
+
+namespace iree {
+namespace split_mlir {
+
+#define GEN_PASS_DECL
+#include "iree/split_mlir/Passes.h.inc"  // IWYU pragma: export
+
+std::unique_ptr<OperationPass<ModuleOp>> createOutlineFunctionsPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createMarkBisectPass();
+
+#define GEN_PASS_REGISTRATION
+#include "iree/split_mlir/Passes.h.inc"  // IWYU pragma: export
+
+}  // namespace split_mlir
+}  // namespace iree
+}  // namespace mlir
+
+#endif  // IREE_SPLIT_MLIR_TRANSFORM_PASSES_H_
diff --git a/experimental/split_mlir/src/iree/split_mlir/Passes.td b/experimental/split_mlir/src/iree/split_mlir/Passes.td
new file mode 100644
index 000000000000..f07f37dfd5a6
--- /dev/null
+++ b/experimental/split_mlir/src/iree/split_mlir/Passes.td
@@ -0,0 +1,79 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_SPLIT_MLIR_TRANSFORM_PASSES
+#define IREE_SPLIT_MLIR_TRANSFORM_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def OutlineFunctions :
+    Pass<"iree-outline-functions", "mlir::ModuleOp"> {
+  let summary = "Outline operations into separate function(s).";
+  let description = [{
+    Marked operation ranges in a function block are outlined/moved into new functions.
+    A call to the new function is inserted in place of each outlined operation range.
+    The resulting function is equivalent to the original.
+    The ranges for outlining must be marked with the attributes
+    `outline_range_first`, `outline_range_last`.
+
+    Example:
+    ```mlir
+    func.func @f(%arg0: i32, %arg1: i32) -> i32 {
+      %0 = arith.addi %arg0, %arg1 {outline_range_first} : i32
+      %1 = arith.muli %0, %0 : i32
+      %2 = arith.muli %1, %1 {outline_range_last} : i32
+      %3 = arith.addi %2, %2 : i32
+      return %3 : i32
+    }
+    ```
+
+    The above MLIR will be transformed to:
+    ```mlir
+    func.func @f(%arg0: i32, %arg1: i32) -> i32 {
+      %0 = call @f_outline_0_0(%arg0, %arg1) : (i32, i32) -> i32
+      %1 = arith.addi %0, %0 : i32
+      return %1 : i32
+    }
+    func.func @f_outline_0_0(%arg0: i32, %arg1: i32) -> i32 {
+      %0 = arith.addi %arg0, %arg1 : i32
+      %1 = arith.muli %0, %0 : i32
+      %2 = arith.muli %1, %1 : i32
+      return %2 : i32
+    }
+    ```
+
+    The pass will fail if there is branching to other function blocks
+    inside a marked operation range.
+  }];
+  let constructor = "mlir::iree::split_mlir::createOutlineFunctionsPass()";
+  let dependentDialects = ["mlir::func::FuncDialect"];
+}
+
+def MarkBisect : Pass<"iree-mark-bisect", "mlir::func::FuncOp"> {
+  let summary = "Mark operations in function(s) for outlining with a bisect strategy.";
+  let description = [{
+    Each function's entry block is bisected,
+    such that each piece has a balanced number of ops.
+    The two pieces are marked with attributes `outline_range_first` and
+    `outline_range_last`. These markings serve as input to the `OutlineFunctions` pass.
+ + Example: + ```bash + iree-opt \ + --iree-plugin=split_mlir \ + --pass-pipeline="builtin.module(func.func(iree-mark-bisect{functions=f,g}))" + my.mlir + ``` + + }]; + let constructor = "mlir::iree::split_mlir::createMarkBisectPass()"; + let options = [ + ListOption<"functions", "functions", "std::string", "List of functions to bisect."> + ]; + let dependentDialects = ["mlir::func::FuncDialect"]; +} + +#endif // IREE_SPLIT_MLIR_TRANSFORM_PASSES diff --git a/experimental/split_mlir/src/iree/split_mlir/PluginRegistration.cpp b/experimental/split_mlir/src/iree/split_mlir/PluginRegistration.cpp new file mode 100644 index 000000000000..762bfe327b7d --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/PluginRegistration.cpp @@ -0,0 +1,32 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/PluginAPI/Client.h" +#include "iree/split_mlir/Passes.h" + + +using namespace mlir; +using namespace mlir::iree_compiler; + +namespace { + +struct SplitMlirOptions { + void bindOptions(OptionsBinder &binder) {} +}; + +struct SplitMlirSession : public PluginSession { + static void registerPasses() { + iree::split_mlir::registerPasses(); + } +}; +} // namespace + +IREE_DEFINE_COMPILER_OPTION_FLAGS(SplitMlirOptions); + +extern "C" bool iree_register_compiler_plugin_split_mlir(PluginRegistrar *registrar) { + registrar->registerPlugin("split_mlir"); + return true; +} diff --git a/experimental/split_mlir/src/iree/split_mlir/SplitMlirPyExt.cpp b/experimental/split_mlir/src/iree/split_mlir/SplitMlirPyExt.cpp new file mode 100644 index 000000000000..d1b5560fe660 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/SplitMlirPyExt.cpp @@ -0,0 +1,37 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include + +#include "OperationListImpl.h" + +namespace py = pybind11; + +namespace mlir { +namespace iree { +namespace split_mlir { + +PYBIND11_MODULE(_split_mlir, m) { + m.doc() = "Split MLIR C++ extension"; + + m.def( + "extract_operation_list", + [](const std::string& mlirFilePath, const std::string& functionName) { + auto res = extractOperationList(mlirFilePath.c_str(), functionName); + if (failed(res)) { + throw std::runtime_error(""); + } + return res.value(); + }, + py::arg("mlir_file_path"), py::arg("function_name")); +} + +} // namespace split_mlir +} // namespace iree +} // namespace mlir diff --git a/experimental/split_mlir/src/iree/split_mlir/__init__.py b/experimental/split_mlir/src/iree/split_mlir/__init__.py new file mode 100644 index 000000000000..f66106cd2b4a --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._split_mlir import extract_operation_list
+from .execution import execute_operation_list
+from .iree_execution import IreeExecutor, execute_mlir_with_iree
+
+__all__ = [
+    "execute_operation_list", "execute_mlir_with_iree",
+    "extract_operation_list", "IreeExecutor"
+]
diff --git a/experimental/split_mlir/src/iree/split_mlir/_split_mlir.pyi b/experimental/split_mlir/src/iree/split_mlir/_split_mlir.pyi
new file mode 100644
index 000000000000..f21ca121c93e
--- /dev/null
+++ b/experimental/split_mlir/src/iree/split_mlir/_split_mlir.pyi
@@ -0,0 +1,12 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .types import OperationList
+
+
+def extract_operation_list(mlir_file_path: str,
+                           function_name: str) -> OperationList:
+  ...
diff --git a/experimental/split_mlir/src/iree/split_mlir/execution.py b/experimental/split_mlir/src/iree/split_mlir/execution.py
new file mode 100644
index 000000000000..ddf8988fb39e
--- /dev/null
+++ b/experimental/split_mlir/src/iree/split_mlir/execution.py
@@ -0,0 +1,42 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import List, Optional
+from .types import OpArguments, Tensor, ExecuteOp, OperationList
+
+
+def collect_arguments(arguments: OpArguments,
+                      results: List[List[Tensor]]) -> List[Tensor]:
+  return [results[arg[0]][arg[1]] for arg in arguments]  # type: ignore
+
+
+def execute_operation_list(
+    input: List[Tensor],
+    operation_list: OperationList,
+    execute_op: ExecuteOp,
+    override_results: Optional[List[List[Tensor]]] = None
+) -> List[List[Tensor]]:
+  """Algorithm to execute a call list.
+
+  Parameters
+  ----------
+  input : Input of the graph.
+  execute_op : Callable that executes an operation from the graph.
+  override_results : When executing operations, override arguments with these values
+    instead of using the computed results from previous functions.
+
+  Returns
+  -------
+  All results from all operations in the graph are in the same order
+  as they appear in `operation_list`. `input` is prepended to the result.
+  """
+  results = [input]
+  for op in operation_list[1:]:
+    arguments = collect_arguments(
+        arguments=op[1],
+        results=results if override_results is None else override_results)
+    results.append(execute_op(op[0], arguments))
+  return results
diff --git a/experimental/split_mlir/src/iree/split_mlir/iree_execution.py b/experimental/split_mlir/src/iree/split_mlir/iree_execution.py
new file mode 100644
index 000000000000..cfcfe92a3a90
--- /dev/null
+++ b/experimental/split_mlir/src/iree/split_mlir/iree_execution.py
@@ -0,0 +1,122 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import List, Callable, Tuple, Dict, Optional, Any
+from .types import Tensor
+import iree.runtime
+from iree.runtime import VmModule, HalDevice, load_vm_module
+from .execution import execute_operation_list
+from tempfile import TemporaryDirectory
+from iree.compiler.tools import compile_file
+import os
+from pathlib import Path
+from collections import namedtuple
+from ._split_mlir import extract_operation_list
+from numbers import Number
+
+VmfbFilePath = str
+MlirFilePath = str
+FunctionName = str
+
+
+class IreeExecutor:
+  """Executor for IREE that implements the `.types.ExecuteOp` interface."""
+
+  def __init__(self, device: HalDevice,
+               resolve_function: Callable[[FunctionName], Tuple[VmfbFilePath,
+                                                                FunctionName]]):
+    """
+    Parameters
+    ----------
+    resolve_function : Resolves a function name that is called in the entry MLIR to
+      the vmfb file where it can be found under another name.
+    """
+    self.device = device
+    self.resolve_function = resolve_function
+
+  def __call__(self, op_id: str, operands: List[Tensor]) -> List[Tensor]:
+    if op_id.startswith("call "):
+      function_name = op_id.split(" ", 2)[1]
+      vmfb_file_path, vmfb_function_name = self.resolve_function(function_name)
+      config = iree.runtime.Config(device=self.device)
+      with open(vmfb_file_path, "rb") as f:
+        vm_flatbuffer = f.read()
+      vm_module_fb_bytes = VmModule.from_flatbuffer(config.vm_instance,
+                                                    vm_flatbuffer)
+      vm_module = load_vm_module(vm_module_fb_bytes, config)
+      res = getattr(vm_module, vmfb_function_name)(*operands)
+      if isinstance(res, (iree.runtime.DeviceArray, Number)):
+        res = [res]
+      return res
+    if op_id == "return":
+      return operands
+    raise RuntimeError(f"Invalid op_id \"{op_id}\".")
+
+
+def mlir_to_vmfb_file_path(mlir_file_path: str) -> str:
+  return f"{Path(mlir_file_path).stem}.vmfb"
+
+
+def execute_mlir_with_iree(input: List[Tensor],
+                           mlir_path_function_pairs: List[Tuple[MlirFilePath,
+                                                                FunctionName]],
+                           compile_kwargs: Dict[str, Any],
+                           device: HalDevice,
+                           override_results: Optional[List[
+                               List[Tensor]]] = None,
+                           artifact_dir: Optional[str] = None
+                          ) -> List[List[Tensor]]:
+  """Executes an MLIR program that is split across multiple MLIR files.
+  Parameters
+  ----------
+  mlir_path_function_pairs : List of MLIR files and the function they contain.
+    The first element is the entry MLIR and function.
+    It is expected that the name of a function called in the entry function corresponds
+    to an MLIR file with the same name without the file name extension.
+  compile_kwargs : Compile arguments to pass to iree.compiler.tools.compile_file.
+  artifact_dir : Where to put temporary files.
+    Defaults to creating a unique temporary directory that is deleted on completion.
+ + See: `execute_operation_list` + """ + if artifact_dir is None: + with TemporaryDirectory() as temp_dir: + return execute_mlir_with_iree( + input=input, + mlir_path_function_pairs=mlir_path_function_pairs, + override_results=override_results, + compile_kwargs=compile_kwargs, + device=device, + artifact_dir=temp_dir) + + entry_mlir_file_path = mlir_path_function_pairs[0][0] + entry_function_name = mlir_path_function_pairs[0][1] + FunctionDescription = namedtuple( + "FunctionDescription", + ["mlir_file_path", "vmfb_file_path", "function_name"]) + function_map = { + Path(Path(p[0]).name).stem: FunctionDescription( + p[0], os.path.join(artifact_dir, mlir_to_vmfb_file_path(p[0])), p[1]) + for p in mlir_path_function_pairs + } + for i in range(1, len(mlir_path_function_pairs)): + function_description = function_map[Path( + Path(mlir_path_function_pairs[i][0]).name).stem] + compile_file(function_description.mlir_file_path, + output_file=function_description.vmfb_file_path, + **compile_kwargs) + + def resolve_function( + function_name: FunctionName) -> Tuple[VmfbFilePath, FunctionName]: + func_desc = function_map[function_name] + return (func_desc.vmfb_file_path, func_desc.function_name) + + executor = IreeExecutor(device=device, resolve_function=resolve_function) + operation_list = extract_operation_list(mlir_file_path=entry_mlir_file_path, + function_name=entry_function_name) + return execute_operation_list(operation_list=operation_list, + execute_op=executor, + input=input, + override_results=override_results) diff --git a/experimental/split_mlir/src/iree/split_mlir/test/CMakeLists.txt b/experimental/split_mlir/src/iree/split_mlir/test/CMakeLists.txt new file mode 100644 index 000000000000..27c8e39e0be0 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/CMakeLists.txt @@ -0,0 +1,27 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_add_all_subdirs() + +iree_lit_test_suite( + NAME + lit + SRCS + "mark_bisect.mlir" + TOOLS + FileCheck + iree-opt +) + +iree_lit_test_suite( + NAME + lit + SRCS + "function_outlining.mlir" + TOOLS + FileCheck + iree-opt +) diff --git a/experimental/split_mlir/src/iree/split_mlir/test/execution/CMakeLists.txt b/experimental/split_mlir/src/iree/split_mlir/test/execution/CMakeLists.txt new file mode 100644 index 000000000000..77611fc4278a --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/execution/CMakeLists.txt @@ -0,0 +1,18 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_add_all_subdirs() + +iree_local_py_test( + NAME + execution_test + SRC + execution_test.py + PACKAGE_DIRS + "${IREE_BINARY_DIR}/compiler/bindings/python" + "${IREE_BINARY_DIR}/runtime/bindings/python" + "${IREE_BINARY_DIR}/compiler/plugins/split_mlir" +) diff --git a/experimental/split_mlir/src/iree/split_mlir/test/execution/entry.mlir b/experimental/split_mlir/src/iree/split_mlir/test/execution/entry.mlir new file mode 100644 index 000000000000..ea6cd8b5f4e5 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/execution/entry.mlir @@ -0,0 +1,14 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +func.func nested @f1(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) +func.func nested @f2(%arg0: tensor<1xf32>) -> tensor<1xf32> + +func.func @caller(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) { + %0:2 = call @f1(%arg0, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) + %1 = call @f2(%0#0) : (tensor<1xf32>) -> tensor<1xf32> + return %arg1, %1 : tensor<1xf32>, tensor<1xf32> +} diff --git a/experimental/split_mlir/src/iree/split_mlir/test/execution/execution_test.py b/experimental/split_mlir/src/iree/split_mlir/test/execution/execution_test.py new file mode 100644 index 000000000000..98e5b766eacd --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/execution/execution_test.py @@ -0,0 +1,57 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest +from iree.split_mlir import extract_operation_list, execute_mlir_with_iree +from typing import List, Any +import os +import iree.runtime +import numpy as np + + +def assert_nested_array_equals(a: List[Any], b: List[Any]): + assert a == b, f"{a} != {b}" + + +class ExecutionTest(unittest.TestCase): + + def test_extract_operation_list(self): + expected_operation_list = [ + ("", []), + ("call f1", [(0, 0), (0, 0)]), + ("call f2", [(1, 0)]), + ("return", [(0, 1), (2, 0)]), + ] + operation_list = extract_operation_list(mlir_file_path=os.path.join( + os.path.dirname(__file__), "entry.mlir"), + function_name="caller") + assert_nested_array_equals(expected_operation_list, operation_list) + + def test_mlir_execution(self): + mlir_path_function_pairs = [ + (os.path.join(os.path.dirname(__file__), "entry.mlir"), "caller"), + (os.path.join(os.path.dirname(__file__), "f1.mlir"), "f1"), + (os.path.join(os.path.dirname(__file__), "f2.mlir"), "main"), + ] + compile_kwargs = { + "target_backends": ["llvm-cpu"], + } + device = iree.runtime.get_device("local-task") + input = [np.array([1], dtype=np.float32), np.array([2], dtype=np.float32)] + results = execute_mlir_with_iree( + input=input, + mlir_path_function_pairs=mlir_path_function_pairs, + compile_kwargs=compile_kwargs, + device=device) + expected_output = [ + np.array([2], dtype=np.float32), + np.array([4], dtype=np.float32) + ] + assert_nested_array_equals(results[-1], expected_output) + + +if __name__ == "__main__": + unittest.main() diff --git a/experimental/split_mlir/src/iree/split_mlir/test/execution/f1.mlir b/experimental/split_mlir/src/iree/split_mlir/test/execution/f1.mlir new file mode 100644 index 000000000000..8ca9609123a7 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/execution/f1.mlir @@ -0,0 +1,10 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +func.func @f1(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) { + %0 = arith.addf %arg0, %arg1: tensor<1xf32> + return %0, %0 : tensor<1xf32>, tensor<1xf32> +} diff --git a/experimental/split_mlir/src/iree/split_mlir/test/execution/f2.mlir b/experimental/split_mlir/src/iree/split_mlir/test/execution/f2.mlir new file mode 100644 index 000000000000..6b0bc790fdbb --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/execution/f2.mlir @@ -0,0 +1,10 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +func.func @main(%arg0: tensor<1xf32>) -> tensor<1xf32> { + %0 = arith.addf %arg0, %arg0 : tensor<1xf32> + return %0 : tensor<1xf32> +} diff --git a/experimental/split_mlir/src/iree/split_mlir/test/function_outlining.mlir b/experimental/split_mlir/src/iree/split_mlir/test/function_outlining.mlir new file mode 100644 index 000000000000..cb8df4596192 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/function_outlining.mlir @@ -0,0 +1,102 @@ +// RUN: iree-opt \ +// RUN: --split-input-file \ +// RUN: --iree-plugin=split_mlir \ +// RUN: --pass-pipeline="builtin.module(iree-outline-functions)" %s \ +// RUN: | FileCheck --dump-input-context=100 %s + +// Outline op that does not take any arguments and is not used anywhere. +// CHECK-LABEL: func.func @no_args_and_result +func.func @no_args_and_result() { +// CHECK: call @no_args_and_result_outline_0_0() : () -> () + %cts1 = mhlo.constant {outline_range_first, outline_range_last} dense<1.000000e+00> : tensor<2xf32> +// CHECK-NEXT: {{return$}} + return +} +// CHECK-LABEL: func.func @no_args_and_result_outline_0_0() +// CHECK: mhlo.constant dense<{{.+}}> : tensor<2xf32> +// CHECK-NOT: outline_range_first +// CHECK-NOT: outline_range_last +// CHECK-NEXT: {{return$}} + +// ----- + +// Outline an op that takes one argument and has one result that is used. +// CHECK-LABEL: func.func @one_arg_and_one_result +// CHECK-SAME: ([[ARG0:%.+]]: tensor<2xf32>) -> tensor<2xf32> +func.func @one_arg_and_one_result(%arg0: tensor<2xf32>) -> tensor<2xf32> { +// CHECK-NEXT: [[RES0:%.+]] = call @one_arg_and_one_result_outline_0_0([[ARG0]]) + %res = mhlo.cosine %arg0 {outline_range_first, outline_range_last} : tensor<2xf32> +// CHECK-NEXT: return [[RES0]] : tensor<2xf32> + return %res : tensor<2xf32> +} +// CHECK-LABEL: func.func @one_arg_and_one_result_outline_0_0 +// CHECK-SAME: ([[ARG1:%.+]]: tensor<2xf32>) -> tensor<2xf32> +// CHECK-NEXT: [[RES1:%.+]] = mhlo.cosine [[ARG1]] : tensor<2xf32> +// CHECK-NEXT: return [[RES1]] : tensor<2xf32> + +// ----- + +// Multiple ops in a range with multiple arguments and results. 
+// CHECK-LABEL: func.func @multiple_ops +// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> (i32, i32) +func.func @multiple_ops(%arg0: i32, %arg1: i32) -> (i32, i32) { +// CHECK-NEXT: [[RES0:%.+]]:2 = call @multiple_ops_outline_0_0([[ARG0]], [[ARG1]]) : (i32, i32) -> (i32, i32) + %add = arith.addi %arg0, %arg0 {outline_range_first} : i32 + %mul = arith.muli %add, %arg1 {outline_range_last} : i32 +// CHECK-NEXT: return [[RES0]]#0, [[RES0]]#1 : i32, i32 + return %add, %mul : i32, i32 +} +// CHECK-LABEL: func.func @multiple_ops_outline_0_0 +// CHECK-SAME: ([[ARG10:%.+]]: i32, [[ARG11:%.+]]: i32) -> (i32, i32) +// CHECK-NEXT: [[ADD:%.+]] = arith.addi [[ARG10]], [[ARG10]] : i32 +// CHECK-NEXT: [[MUL:%.+]] = arith.muli [[ADD]], [[ARG11]] : i32 +// CHECK-NEXT: return [[ADD]], [[MUL]] : i32, i32 + +// ----- + +// Outline multiple ranges in the same function. +// CHECK-LABEL: func.func @multiple_ranges_in_same_func +// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> (i32, i32) +func.func @multiple_ranges_in_same_func(%arg0: i32, %arg1: i32) -> (i32, i32) { +// CHECK-NEXT: [[ADD:%.+]] = call @multiple_ranges_in_same_func_outline_0_0([[ARG0]]) : (i32) -> i32 + %add = arith.addi %arg0, %arg0 {outline_range_first, outline_range_last} : i32 +// CHECK-NEXT: [[MUL:%.+]] = call @multiple_ranges_in_same_func_outline_0_1([[ADD]], [[ARG1]]) : (i32, i32) -> i32 + %mul = arith.muli %add, %arg1 {outline_range_first, outline_range_last} : i32 +// CHECK-NEXT: return [[ADD]], [[MUL]] : i32, i32 + return %add, %mul : i32, i32 +} +// CHECK-LABEL: func.func @multiple_ranges_in_same_func_outline_0_0 +// CHECK-SAME: ([[ARG10:%.+]]: i32) -> i32 +// CHECK-NEXT: [[ADD1:%.+]] = arith.addi [[ARG10]], [[ARG10]] : i32 +// CHECK-NEXT: return [[ADD1]] : i32 +// CHECK-LABEL: func.func @multiple_ranges_in_same_func_outline_0_1 +// CHECK-SAME: ([[ARG20:%.+]]: i32, [[ARG21:%.+]]: i32) -> i32 +// CHECK-NEXT: [[MUL2:%.+]] = arith.muli [[ARG20]], [[ARG21]] : i32 +// CHECK-NEXT: return [[MUL2]] : i32 + +// ----- + +// Outline multiple ranges in different blocks. 
+// CHECK-LABEL: func.func @multiple_ranges_in_different_blocks
+// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32
+func.func @multiple_ranges_in_different_blocks(%arg0: i32, %arg1: i32) -> i32 {
+// CHECK-NEXT: [[ADD:%.+]] = call @multiple_ranges_in_different_blocks_outline_0_0([[ARG0]]) : (i32) -> i32
+  %add = arith.addi %arg0, %arg0 {outline_range_first, outline_range_last} : i32
+// CHECK-NEXT: cf.br ^bb1([[ARG1]] : i32)
+  cf.br ^bb1(%arg1 : i32)
+// CHECK-NEXT: ^bb1
+// CHECK-SAME: ([[ARG2:%.+]]: i32)
+^bb1 (%arg2: i32):
+// CHECK-NEXT: [[MUL:%.+]] = call @multiple_ranges_in_different_blocks_outline_1_0([[ADD]], [[ARG2]]) : (i32, i32) -> i32
+  %mul = arith.muli %add, %arg2 {outline_range_first, outline_range_last} : i32
+// CHECK-NEXT: return [[MUL]] : i32
+  return %mul : i32
+}
+// CHECK-LABEL: func.func @multiple_ranges_in_different_blocks_outline_0_0
+// CHECK-SAME: ([[ARG10:%.+]]: i32) -> i32
+// CHECK-NEXT: [[ADD1:%.+]] = arith.addi [[ARG10]], [[ARG10]] : i32
+// CHECK-NEXT: return [[ADD1]] : i32
+// CHECK-LABEL: func.func @multiple_ranges_in_different_blocks_outline_1_0
+// CHECK-SAME: ([[ARG20:%.+]]: i32, [[ARG21:%.+]]: i32) -> i32
+// CHECK-NEXT: [[MUL2:%.+]] = arith.muli [[ARG20]], [[ARG21]] : i32
+// CHECK-NEXT: return [[MUL2]] : i32
diff --git a/experimental/split_mlir/src/iree/split_mlir/test/mark_bisect.mlir b/experimental/split_mlir/src/iree/split_mlir/test/mark_bisect.mlir
new file mode 100644
index 000000000000..3a4cb72424a7
--- /dev/null
+++ b/experimental/split_mlir/src/iree/split_mlir/test/mark_bisect.mlir
@@ -0,0 +1,68 @@
+// RUN: iree-opt \
+// RUN: --split-input-file \
+// RUN: --iree-plugin=split_mlir \
+// RUN: --pass-pipeline="builtin.module(func.func(iree-mark-bisect{functions=two_ops,too_few_ops,multiple_ops}))" %s \
+// RUN: | FileCheck %s
+
+// Each operation is marked as separate range.
+// CHECK-LABEL: func.func @two_ops
+func.func @two_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+// CHECK: mhlo.constant
+// CHECK-DAG: outline_range_first
+// CHECK-DAG: outline_range_last
+  %cts1 = mhlo.constant dense<1.000000e+00> : tensor<2xf32>
+// CHECK-NEXT: mhlo.add
+// CHECK-DAG: outline_range_first
+// CHECK-DAG: outline_range_last
+  %res = mhlo.add %arg0, %cts1 : tensor<2xf32>
+  return %res : tensor<2xf32>
+}
+
+// -----
+
+// Degenerate case with too few ops should not mark anything.
+// CHECK-LABEL: func.func @too_few_ops
+func.func @too_few_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+// CHECK: mhlo.constant
+// CHECK-NOT: outline_range_first
+// CHECK-NOT: outline_range_last
+  %cts1 = mhlo.constant dense<1.000000e+00> : tensor<2xf32>
+// CHECK-NEXT: return
+// CHECK-NOT: outline_range_first
+// CHECK-NOT: outline_range_last
+  return %cts1 : tensor<2xf32>
+}
+
+// -----
+
+// Multiple ops per range.
+// CHECK-LABEL: func.func @multiple_ops
+func.func @multiple_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+// CHECK: outline_range_first
+// CHECK-SAME: dense<1.000000e+00>
+  %cts1 = mhlo.constant dense<1.000000e+00> : tensor<2xf32>
+// CHECK-NEXT: outline_range_last
+// CHECK-SAME: dense<2.000000e+00>
+  %cts2 = mhlo.constant dense<2.000000e+00> : tensor<2xf32>
+// CHECK-NEXT: outline_range_first
+// CHECK-SAME: dense<3.000000e+00>
+  %cts3 = mhlo.constant dense<3.000000e+00> : tensor<2xf32>
+// CHECK-NEXT: outline_range_last
+// CHECK-SAME: dense<4.000000e+00>
+  %cts4 = mhlo.constant dense<4.000000e+00> : tensor<2xf32>
+// CHECK-NEXT: return
+  return %cts1 : tensor<2xf32>
+}
+
+// -----
+
+// Non-listed functions should not be marked.
+// CHECK-LABEL: func.func @function_not_to_mark
+func.func @function_not_to_mark(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+// CHECK-NOT: outline_range_first
+// CHECK-NOT: outline_range_last
+  %cts1 = mhlo.constant dense<1.000000e+00> : tensor<2xf32>
+  %res = mhlo.add %arg0, %cts1 : tensor<2xf32>
+// CHECK: return
+  return %res : tensor<2xf32>
+}
diff --git a/experimental/split_mlir/src/iree/split_mlir/types.py b/experimental/split_mlir/src/iree/split_mlir/types.py
new file mode 100644
index 000000000000..418bc207ab91
--- /dev/null
+++ b/experimental/split_mlir/src/iree/split_mlir/types.py
@@ -0,0 +1,19 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import List, TypeVar, Callable, Tuple
+from numbers import Integral
+
+Tensor = TypeVar("Tensor")
+OpId = TypeVar("OpId")
+ExecuteOp = Callable[[OpId, List[Tensor]], List[Tensor]]
+OperationIndex = Integral
+ResultIndex = Integral
+"""Description of the dependencies of an operation."""
+OpArguments = List[Tuple[OperationIndex, ResultIndex]]
+Operation = Tuple[OpId, OpArguments]
+"""Describes a dependency graph of operations."""
+OperationList = List[Operation]
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
index 9627486a62cc..54f9e947d36d 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
@@ -115,7 +115,7 @@ createTileAndDecomposeWinogradTransformPass();
 // Creates a pass to convert linalg convolution ops into a sequence of
 // linalg_ext.winograd.* ops and linalg.batch_matmul ops using the winograd
 // tranformation.
-std::unique_ptr createConvertConv2DToWinogradPass();
+std::unique_ptr createConvertConv2DToWinogradPass(bool forceWinograd = false);
 
 // Transform dialect version of tile and decompose attention wrapper.
 void tileAndDecomposeAttention(IREE::LinalgExt::AttentionOp attnOp,
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h
index 77dbb09135b2..62d93d14e5aa 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h
@@ -82,6 +82,62 @@ const float A_6x6_3x3[] = {
 
 // clang-format on
 
+//===----------------------------------------------------------------------===//
+// Output tile size = 4, Kernel size = 3
+//===----------------------------------------------------------------------===//
+// These constants were obtained from this paper:
+//
+// Lavin, A. et al. (2016) Fast Algorithms for Convolutional Neural Networks.
+// https://openaccess.thecvf.com/content_cvpr_2016/papers/Lavin_Fast_Algorithms_for_CVPR_2016_paper.pdf +// + +// clang-format off + +const float BT_4x4_3x3[] = { + 4, 0, -5, 0, 1, 0, + 0, -4, -4, 1, 1, 0, + 0, 4, -4, -1, 1, 0, + 0, -2, -1, 2, 1, 0, + 0, 2, -1, -2, 1, 0, + 0, 4, 0, -5, 0, 1 +}; + +const float B_4x4_3x3[] = { + 4, 0, 0, 0, 0, 0, + 0, -4, 4, -2, 2, 4, + -5, -4, -4, -1, -1, 0, + 0, 1, -1, 2, -2, -5, + 1, 1, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 1 +}; + +const float G_4x4_3x3[] = { + 1./4., 0, 0, + -1./6., -1./6., -1./6., + -1./6., 1./6., -1./6., + 1./24., 1./12., 1./6., + 1./24., -1./12., 1./6., + 0, 0, 1 +}; + +const float AT_4x4_3x3[] = { + 1, 1, 1, 1, 1, 0, + 0, 1, -1, 2, -2, 0, + 0, 1, 1, 4, 4, 0, + 0, 1, -1, 8, -8, 1 +}; + +const float A_4x4_3x3[] = { + 1, 0, 0, 0, + 1, 1, 1, 1, + 1, -1, 1, -1, + 1, 2, 4, 8, + 1, -2, 4, -8, + 0, 0, 0, 1 +}; + +// clang-format on + } // namespace Winograd } // namespace LinalgExt } // namespace IREE diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp index 0680be4d8e8e..3a994f1b048f 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp @@ -24,12 +24,16 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SetVector.h" +#include +#include namespace mlir { namespace iree_compiler { namespace IREE { namespace LinalgExt { +static const char winogradAttr[] = "iree_winograd_conv"; + static inline int index(int y, int x, int dimy, int dimx) { return (x + dimx * y); } @@ -45,7 +49,7 @@ static bool hasAllOneValues(DenseIntElementsAttr attr) { // TODO: Make this a user-settable parameter once we have support // for more tile sizes -static constexpr int64_t outputTileSize = 6; +static constexpr int64_t outputTileSize = 4; /// This function computes the Winograd filter transform when /// the filter is known to be a constant. Specifically, this @@ -70,30 +74,50 @@ foldFilterTransform(ArrayRef shape, int64_t inputTileSize, const int &kw = isNchw ? shape[3] : shape[1]; const int &ic = isNchw ? shape[1] : shape[2]; const int &oc = isNchw ? 
shape[0] : shape[3]; + //printf("Folding filter with kh = %d, kw = %d, ic = %d, oc = %d\n", kh, kw, ic, oc); const int64_t numElements = inputTileSize * inputTileSize * ic * oc; + float *alloc{nullptr}; + if (!isSplat) { + alloc = (float *) malloc(kh * kw * ic * oc * sizeof(float)); + for (int d2 = 0; d2 < ic; d2++) { + for (int d3 = 0; d3 < oc; d3++) { + for (int d4 = 0; d4 < kernelSize; d4++) { + for (int d5 = 0; d5 < kernelSize; d5++) { + int idx; + if (!isNchw) { + idx = index(d4, d5, d2, d3, kh, kw, ic, oc); + } else { + idx = index(d3, d2, d4, d5, oc, ic, kh, kw); + } + alloc[idx] = input[idx].convertToFloat(); + } + } + } + } + } SmallVector output(numElements, APFloat(0.0f)); for (int d0 = 0; d0 < inputTileSize; d0++) { for (int d1 = 0; d1 < inputTileSize; d1++) { for (int d2 = 0; d2 < ic; d2++) { for (int d3 = 0; d3 < oc; d3++) { - APFloat accum(0.0f); + float accum(0.0f); for (int d4 = 0; d4 < kernelSize; d4++) { for (int d5 = 0; d5 < kernelSize; d5++) { - APFloat ival(splatValue); + float ival{splatValue}; if (!isSplat) { if (!isNchw) { - ival = input[index(d4, d5, d2, d3, kh, kw, ic, oc)]; + ival = alloc[index(d4, d5, d2, d3, kh, kw, ic, oc)]; } else { - ival = input[index(d3, d2, d4, d5, oc, ic, kh, kw)]; + ival = alloc[index(d3, d2, d4, d5, oc, ic, kh, kw)]; } } int idx0 = index(d0, d4, inputTileSize, kernelSize); int idx1 = index(d1, d5, inputTileSize, kernelSize); - accum = accum + APFloat(G[idx0]) * ival * APFloat(G[idx1]); + accum = accum + G[idx0] * ival * G[idx1]; } } int odx = index(d0, d1, d2, d3, inputTileSize, inputTileSize, ic, oc); - output[odx] = accum; + output[odx] = APFloat(accum); if (floatType.isF16()) { bool losesInfo; output[odx].convert(APFloat::IEEEhalf(), @@ -103,6 +127,7 @@ foldFilterTransform(ArrayRef shape, int64_t inputTileSize, } } } + if (alloc) free(alloc); return DenseElementsAttr::get(outputType, output); } @@ -134,10 +159,16 @@ template class FoldWinogradFilterTransform final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; + FoldWinogradFilterTransform(MLIRContext *context, bool force) + : OpRewritePattern(context, force), forceWinograd(force) {} LogicalResult matchAndRewrite(ConvOp convOp, PatternRewriter &rewriter) const override { + // Attribute control unless forced. 
+ if (!forceWinograd && !convOp->hasAttr(winogradAttr)) + return failure(); + bool isNchw; if (!isValidConv2d(convOp, isNchw)) return failure(); @@ -157,10 +188,11 @@ class FoldWinogradFilterTransform final : public OpRewritePattern { const int64_t kernelSize = kh; const int64_t inputTileSize = outputTileSize + kernelSize - 1; - DenseIntOrFPElementsAttr kernelAttr; - if (!matchPattern(kernel, m_Constant(&kernelAttr))) { + Attribute rawKernelAttr; + if (!matchPattern(kernel, m_Constant(&rawKernelAttr)) || !isa(rawKernelAttr)) { return failure(); } + DenseIntOrFPElementsAttr kernelAttr = cast(rawKernelAttr); Operation *constOp = kernel.getDefiningOp(); ShapedType type = constOp->getResult(0).getType().cast(); @@ -182,11 +214,14 @@ class FoldWinogradFilterTransform final : public OpRewritePattern { auto resultType = RankedTensorType::get(resultShape, elemType); auto foldedKernelAttr = foldFilterTransform(shape, inputTileSize, kernelSize, resultType, - IREE::LinalgExt::Winograd::G_6x6_3x3, isSplat, + IREE::LinalgExt::Winograd::G_4x4_3x3, isSplat, splatValue, nonSplatValues, elemType, isNchw); + rewriter.replaceOpWithNewOp(constOp, foldedKernelAttr); return success(); } + private: + bool forceWinograd; }; } // namespace @@ -283,10 +318,16 @@ template class ConvertConvToWinograd final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; + ConvertConvToWinograd(MLIRContext *context, bool force) + : OpRewritePattern(context, force), forceWinograd(force) {} LogicalResult matchAndRewrite(ConvOp convOp, PatternRewriter &rewriter) const override { + // Attribute control unless forced. + if (!forceWinograd && !convOp->hasAttr(winogradAttr)) + return failure(); + bool isNchw; if (!isValidConv2d(convOp, isNchw)) return failure(); @@ -416,10 +457,14 @@ class ConvertConvToWinograd final : public OpRewritePattern { result.replaceAllUsesWith(winogradOutput); return success(); } + private: + bool forceWinograd; }; struct ConvertConv2DToWinogradPass : ConvertConv2DToWinogradBase { + public: + ConvertConv2DToWinogradPass(bool forceWinograd) : forceWinograd(forceWinograd) {} void getDependentDialects(DialectRegistry ®istry) const override { registry .insert(); @@ -430,18 +475,20 @@ struct ConvertConv2DToWinogradPass patterns.insert, FoldWinogradFilterTransform, ConvertConvToWinograd, - ConvertConvToWinograd>(context); + ConvertConvToWinograd>(context, forceWinograd); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } + private: + bool forceWinograd; }; } // namespace -std::unique_ptr createConvertConv2DToWinogradPass() { - return std::make_unique(); +std::unique_ptr createConvertConv2DToWinogradPass(bool forceWinograd) { + return std::make_unique(forceWinograd); } } // namespace LinalgExt diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp index 5e1bd34a273e..379e88f25276 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp @@ -210,10 +210,20 @@ static LogicalResult decomposeTiledWinogradInputTransformOp( loc, rewriter.getZeroAttr(elementType)); Value scratch = rewriter.create(loc, inputTileSquare, elementType); + const float *BT{nullptr}; const float *B{nullptr}; - B = IREE::LinalgExt::Winograd::B_6x6_3x3; 
- BT = IREE::LinalgExt::Winograd::BT_6x6_3x3; + const int64_t outputTileSize = + tiledWinogradInputTransformOp.getOutputTileSize(); + switch (outputTileSize) { + case 4: + B = IREE::LinalgExt::Winograd::B_4x4_3x3; + BT = IREE::LinalgExt::Winograd::BT_4x4_3x3; + break; + default: + B = IREE::LinalgExt::Winograd::B_6x6_3x3; + BT = IREE::LinalgExt::Winograd::BT_6x6_3x3; + } Value BTV = IREE::LinalgExt::createValueFrom2DConstant( BT, inputTileSize, inputTileSize, loc, rewriter); Value BV = IREE::LinalgExt::createValueFrom2DConstant( @@ -435,14 +445,23 @@ static LogicalResult decomposeTiledWinogradOutputTransformOp( "output operand expected to have rank-2"); ShapedType outputType = tiledWinogradOutputTransformOp.getOutputOperandType(); Type elementType = outputType.getElementType(); + const float *AT{nullptr}; const float *A{nullptr}; - A = IREE::LinalgExt::Winograd::A_6x6_3x3; - AT = IREE::LinalgExt::Winograd::AT_6x6_3x3; const int64_t inputTileSize = tiledWinogradOutputTransformOp.getInputTileSize(); const int64_t outputTileSize = tiledWinogradOutputTransformOp.getOutputTileSize(); + switch (outputTileSize) { + case 4: + A = IREE::LinalgExt::Winograd::A_4x4_3x3; + AT = IREE::LinalgExt::Winograd::AT_4x4_3x3; + break; + default: + A = IREE::LinalgExt::Winograd::A_6x6_3x3; + AT = IREE::LinalgExt::Winograd::AT_6x6_3x3; + } + /// The two values below are the transpose(A) [ATV] /// and A [AV] constant matrices that convert the output /// tile from the Winograd domain to the original domain. diff --git a/runtime/bindings/python/CMakeLists.txt b/runtime/bindings/python/CMakeLists.txt index 44112ce51db9..a863fd2a5fe0 100644 --- a/runtime/bindings/python/CMakeLists.txt +++ b/runtime/bindings/python/CMakeLists.txt @@ -150,6 +150,11 @@ iree_py_library( "iree/_runtime/scripts/iree_run_trace/__main__.py" "iree/_runtime/scripts/iree_run_module/__main__.py" "iree/_runtime/scripts/iree_tracy_capture/__main__.py" + "iree/runtime/distributed/__init__.py" + "iree/runtime/distributed/distributed.py" + "iree/runtime/distributed/run_rank.py" + "iree/runtime/distributed/sharding_pass_validation.py" + "iree/runtime/distributed/utils.py" PYEXT_DEPS iree_runtime_bindings_python_PyExtRt ) diff --git a/runtime/bindings/python/iree/runtime/distributed/__init__.py b/runtime/bindings/python/iree/runtime/distributed/__init__.py new file mode 100644 index 000000000000..86ee5db110cc --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/__init__.py @@ -0,0 +1,9 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .distributed import prepare_shards_io_files, run_ranks + +__all__ = ["prepare_shards_io_files", "run_ranks"] diff --git a/runtime/bindings/python/iree/runtime/distributed/distributed.py b/runtime/bindings/python/iree/runtime/distributed/distributed.py new file mode 100644 index 000000000000..31e0a5e13a42 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/distributed.py @@ -0,0 +1,85 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import sys +import iree.runtime +from iree.runtime.array_interop import DeviceArray +import os +from numpy.typing import ArrayLike +from typing import List, Tuple +import tempfile +import subprocess +from . 
import utils + + +def prepare_shards_io_files( + inputs: List[List[ArrayLike]], out_dir: str +) -> Tuple[List[str], List[str]]: + input_filepaths = [] + output_filepaths = [] + for i in range(len(inputs)): + input_filepath = os.path.join(out_dir, f"shard_{i}", "input.npy") + input_filepaths.append(input_filepath) + os.makedirs(os.path.dirname(input_filepath)) + utils.write_numpy_arrays_to_file(filepath=input_filepath, arrays=inputs[i]) + output_filepath = os.path.join(out_dir, f"shard_{i}", "output.npy") + output_filepaths.append(output_filepath) + return input_filepaths, output_filepaths + + +def run_ranks( + num_ranks: int, + module_filepath: str, + function: str, + inputs: List[List[ArrayLike]], + driver: str, + call_count: int = 1, + measure_execution_time: bool = False, + warmup: int = 0, +) -> List[List[ArrayLike]]: + """ + Start all ranks with mpirun. + On all ranks run the function |function| from the given module. + Parameters + ---------- + inputs : Function inputs for all ranks. + Axis 0 is ranks. Axis 1 is arguments per rank. + Returns + ------- + The output of the function for all ranks. + Axis 0 is ranks. Axis 1 is arguments per rank. + """ + with tempfile.TemporaryDirectory() as out_dir: + input_filepaths, output_filepaths = prepare_shards_io_files( + inputs=inputs, out_dir=out_dir + ) + hal_driver = iree.runtime.get_driver(driver) + hal_driver.query_available_devices() + subprocess.check_call( + [ + "mpirun", + "--oversubscribe", + "-n", + str(num_ranks), + sys.executable, + os.path.join(os.path.dirname(__file__), "run_rank.py"), + f"--driver={driver}", + f"--module_filepath={module_filepath}", + f"--function={function}", + f"--call_count={call_count}", + ] + + (["--measure_execution_time"] if measure_execution_time else []) + + [ + f"--warmup={warmup}", + "--inputs", + ] + + input_filepaths + + ["--outputs"] + + output_filepaths + ) + return [ + utils.read_numpy_arrays_from_file(out_file) for out_file in output_filepaths + ] diff --git a/runtime/bindings/python/iree/runtime/distributed/run_rank.py b/runtime/bindings/python/iree/runtime/distributed/run_rank.py new file mode 100644 index 000000000000..86761d3172b2 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/run_rank.py @@ -0,0 +1,131 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse +import iree.runtime +from iree.runtime.array_interop import DeviceArray +from mpi4py import MPI +import utils +import datetime +import numpy as np + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run 1 shard.") + parser.add_argument("--driver", type=str, default="local-task", help="Device URI.") + parser.add_argument( + "--module_filepath", type=str, required=True, help="Path to IREE module." + ) + parser.add_argument( + "--function", type=str, required=True, help="Name of function to call." 
+    )
+    parser.add_argument(
+        "--call_count",
+        type=int,
+        default=1,
+        help="How many times to call the function during time measurement.",
+    )
+    parser.add_argument(
+        "--measure_execution_time",
+        action="store_true",
+        default=False,
+        help="Measure execution time in seconds (f64) and append to results.",
+    )
+    parser.add_argument(
+        "--warmup",
+        type=int,
+        default=0,
+        help="How many warmup calls to do before the actual call that generates the result.",
+    )
+    parser.add_argument(
+        "--inputs",
+        nargs="+",
+        type=str,
+        required=True,
+        help="Path to IREE module inputs for all ranks in npy format.",
+    )
+    parser.add_argument(
+        "--outputs",
+        nargs="+",
+        type=str,
+        required=True,
+        help="Path to IREE module outputs from all ranks in npy format.",
+    )
+    return parser.parse_args()
+
+
+def run_module(
+    device: iree.runtime.HalDevice,
+    module_filepath: str,
+    function: str,
+    call_count: int,
+    input_filepath: str,
+    output_filepath: str,
+    measure_execution_time: bool,
+    warmup: int,
+):
+    config = iree.runtime.Config(device=device)
+    with open(module_filepath, "rb") as f:
+        vm_flatbuffer = f.read()
+    vm_module = iree.runtime.VmModule.from_flatbuffer(config.vm_instance, vm_flatbuffer)
+    bound_module = iree.runtime.load_vm_module(vm_module, config)
+    input_args = utils.read_numpy_arrays_from_file(input_filepath)
+    input_args_on_device = [
+        iree.runtime.asdevicearray(device, arr) for arr in input_args
+    ]
+    for _ in range(warmup):
+        getattr(bound_module, function)(*input_args_on_device)
+    if measure_execution_time:
+        # Sync all ranks
+        MPI.COMM_WORLD.barrier()
+        start_time = datetime.datetime.now()
+    assert call_count > 0
+    for _ in range(call_count):
+        results = getattr(bound_module, function)(*input_args_on_device)
+    if measure_execution_time:
+        end_time = datetime.datetime.now()
+    if isinstance(results, DeviceArray):
+        results = [results]
+    if measure_execution_time:
+        if isinstance(results, tuple):
+            results = list(results)
+        results.append(
+            np.array((end_time - start_time).total_seconds() / call_count, dtype=float)
+        )
+    utils.write_numpy_arrays_to_file(filepath=output_filepath, arrays=results)
+
+
+def run_rank(
+    driver: str,
+    module_filepath: str,
+    function: str,
+    inputs: str,
+    outputs: str,
+    call_count: int,
+    measure_execution_time: bool,
+    warmup: int,
+):
+    rank = MPI.COMM_WORLD.Get_rank()
+    hal_driver = iree.runtime.get_driver(driver)
+    device_infos = hal_driver.query_available_devices()
+    device = hal_driver.create_device(
+        device_infos[rank % len(device_infos)]["device_id"]
+    )
+    run_module(
+        device=device,
+        module_filepath=module_filepath,
+        function=function,
+        call_count=call_count,
+        input_filepath=inputs[rank],
+        output_filepath=outputs[rank],
+        measure_execution_time=measure_execution_time,
+        warmup=warmup,
+    )
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    run_rank(**vars(args))
diff --git a/runtime/bindings/python/iree/runtime/distributed/setup.sh b/runtime/bindings/python/iree/runtime/distributed/setup.sh
new file mode 100644
index 000000000000..83dca488caa4
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/distributed/setup.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+set -e
+
+distribution=$(. /etc/os-release;echo $ID$VERSION_ID | sed -e 's/\.//g')
+wget -O /tmp/cuda-keyring_1.0-1_all.deb \
+  https://developer.download.nvidia.com/compute/cuda/repos/$distribution/x86_64/cuda-keyring_1.0-1_all.deb
+sudo dpkg -i /tmp/cuda-keyring_1.0-1_all.deb
+sudo apt update
+# For CMake to find CUDA when using LLD.
+sudo apt -y install lld + +sudo apt -y install libopenmpi-dev +sudo apt -y install libnccl-dev=2.18.1-1+cuda12.1 +pip install mpi4py jax[cpu] diff --git a/runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py b/runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py new file mode 100644 index 000000000000..599d6604b8a8 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py @@ -0,0 +1,210 @@ +import iree.compiler +import iree.runtime +import os +from .distributed import run_ranks +import subprocess +from pathlib import Path +from jax._src.lib import xla_client +from jaxlib.xla_client import HloSharding +from typing import List, Tuple, Union +from numpy.typing import ArrayLike +import jax +from jax._src.sharding_impls import GSPMDSharding +import jax._src.interpreters.pxla as pxla +import numpy as np +from datetime import timedelta + +xla_extension = xla_client._xla + + +def compile_mlir(mlir_filepath: str, output_filepath: str, use_cache: bool, **kwargs): + if use_cache and os.path.exists(output_filepath): + return + os.makedirs(os.path.dirname(output_filepath), exist_ok=True) + iree.compiler.compile_file( + input_file=mlir_filepath, output_file=output_filepath, **kwargs + ) + + +def extract_args_sharding( + xla_computation: xla_extension.XlaComputation, +) -> List[HloSharding]: + return [ + HloSharding.from_proto(sharding) + for sharding in xla_computation.get_hlo_module().spmd_parameters_shardings + ] + + +def extract_results_sharding( + xla_computation: xla_extension.XlaComputation, +) -> List[HloSharding]: + sharding = HloSharding.from_proto( + xla_computation.get_hlo_module().spmd_output_sharding + ) + if len(sharding.tuple_elements()): + return sharding.tuple_elements() + else: + return [sharding] + + +def shard_arg(arg: ArrayLike, sharding: HloSharding) -> List[ArrayLike]: + gspmd_sharding = GSPMDSharding(devices=jax.local_devices(), op_sharding=sharding) + indices = gspmd_sharding.devices_indices_map(arg.shape).values() + sharded_array = pxla.shard_arg( + arg, devices=jax.local_devices(), arg_indices=indices, sharding=gspmd_sharding + ) + return [shard.data for shard in sharded_array.global_shards] + + +def shard_args( + args: List[ArrayLike], shardings: List[HloSharding] +) -> List[List[ArrayLike]]: + assert len(args) == len(shardings) + return [shard_arg(arg, sharding) for arg, sharding in zip(args, shardings)] + + +def assemble_shards(shards: List[ArrayLike], sharding: HloSharding) -> ArrayLike: + if sharding.is_replicated(): + return shards[0] + else: + raise NotImplementedError() + + +def propagate_shardings_and_spmd_partition( + mlir_filepath: str, + output_filepath: str, + num_devices: int, + use_cache: bool, + allow_spmd_sharding_propagation_to_output: int = 1, +): + res = subprocess.run( + [ + "stablehlo-opt", + ( + "--pass-pipeline=builtin.module(stablehlo-xla-sharding-propagation-and-spmd-partitioner{" + "is_spmd=1 " + f"allow_spmd_sharding_propagation_to_output={allow_spmd_sharding_propagation_to_output} " + "allow_spmd_sharding_propagation_to_parameters=1 " + f"num_partitions={num_devices} " + "num_replicas=1})" + ), + mlir_filepath, + ], + check=True, + stdout=subprocess.PIPE, + ) + Path(output_filepath).parent.mkdir(parents=True, exist_ok=True) + if use_cache and os.path.exists(output_filepath): + return + os.makedirs(os.path.dirname(output_filepath), exist_ok=True) + with open(output_filepath, "wb") as f: + f.write(res.stdout) + + +def swap_shard_axis(arrays: List[ArrayLike]) -> 
List[List[ArrayLike]]: + """Swap axis 0 with 1.""" + if len(arrays) == 0: + return [] + expected_shards = len(arrays[0]) + res = [[] for _ in range(expected_shards)] + for arr in arrays: + assert len(arr) == expected_shards + for shard in range(expected_shards): + res[shard].append(arr[shard]) + return res + + +def execute_distributed( + num_ranks: int, + mlir_filepath: str, + iree_module_filepath: str, + function: str, + inputs: List[ArrayLike], + driver: str, + measure_execution_time: bool = False, +) -> Union[List[ArrayLike], Tuple[List[ArrayLike], timedelta]]: + with open(mlir_filepath, "r") as f: + mlir_str = f.read() + xla_computation = xla_extension.mlir.mlir_module_to_xla_computation( + mlir_module=mlir_str, use_tuple_args=False, return_tuple=False + ) + args_sharding = extract_args_sharding(xla_computation) + results_sharding = extract_results_sharding(xla_computation) + sharded_args = shard_args(args=inputs, shardings=args_sharding) + sharded_args = swap_shard_axis(sharded_args) + sharded_results = run_ranks( + num_ranks=num_ranks, + module_filepath=iree_module_filepath, + function=function, + inputs=sharded_args, + driver=driver, + ) + sharded_results = swap_shard_axis(sharded_results) + if measure_execution_time: + sharded_results, execution_times = sharded_results + res = [ + assemble_shards(shards=result_shards, sharding=sharding) + for result_shards, sharding in zip(sharded_results, results_sharding) + ] + if measure_execution_time: + res = res, timedelta(seconds=np.max(execution_times)) + return res + + +def validate_sharding_passes( + mlir_filepath: str, + mlir_with_sharding_annotations_filepath: str, + inputs: List[ArrayLike], + function: str, + num_devices: int, + use_cache: bool, + driver: str, + target_backend: str, + output_prefix_path: str, + allow_spmd_sharding_propagation_to_output: int = 1, +): + # Single instance. + iree_module_filepath = ( + f"{output_prefix_path}{os.path.basename(mlir_filepath)}.{driver}.vmfb" + ) + compile_mlir( + mlir_filepath=mlir_filepath, + output_filepath=iree_module_filepath, + use_cache=use_cache, + target_backends=[target_backend], + ) + iree_module = iree.runtime.load_vm_flatbuffer_file( + path=iree_module_filepath, driver=driver + ) + results = iree_module[function](*inputs) + if isinstance(results, iree.runtime.DeviceArray): + results = [results] + + # Distributed. 
+ spmd_mlir_filepath = f"{output_prefix_path}{os.path.basename(mlir_with_sharding_annotations_filepath)}.spmd.mlir" + propagate_shardings_and_spmd_partition( + mlir_filepath=mlir_with_sharding_annotations_filepath, + output_filepath=spmd_mlir_filepath, + num_devices=num_devices, + use_cache=use_cache, + allow_spmd_sharding_propagation_to_output=allow_spmd_sharding_propagation_to_output, + ) + spmd_iree_module_filepath = f"{output_prefix_path}{os.path.basename(spmd_mlir_filepath)}.{target_backend}.vmfb" + compile_mlir( + mlir_filepath=spmd_mlir_filepath, + output_filepath=spmd_iree_module_filepath, + use_cache=use_cache, + target_backends=[target_backend], + ) + spmd_results = execute_distributed( + num_ranks=num_devices, + mlir_filepath=spmd_mlir_filepath, + iree_module_filepath=spmd_iree_module_filepath, + function=function, + inputs=inputs, + driver=driver, + ) + + assert len(results) == len(spmd_results) + for result, spmd_result in zip(results, spmd_results): + np.testing.assert_allclose(result, spmd_result, atol=1e-7) diff --git a/runtime/bindings/python/iree/runtime/distributed/utils.py b/runtime/bindings/python/iree/runtime/distributed/utils.py new file mode 100644 index 000000000000..3581baf354f8 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/utils.py @@ -0,0 +1,26 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from numpy.typing import ArrayLike +from typing import List +import numpy as np + + +def read_numpy_arrays_from_file(filepath: str) -> List[ArrayLike]: + res = [] + with open(filepath, "rb") as f: + while True: + try: + res.append(np.load(f)) + except EOFError: + break + return res + + +def write_numpy_arrays_to_file(filepath: str, arrays: List[ArrayLike]): + with open(filepath, "wb") as f: + for arr in arrays: + np.save(f, np.asarray(arr), allow_pickle=False) diff --git a/runtime/setup.py b/runtime/setup.py index e561c45ea6e5..b7854c02643f 100644 --- a/runtime/setup.py +++ b/runtime/setup.py @@ -274,7 +274,8 @@ def build_configuration(cmake_build_dir, cmake_install_dir, extra_cmake_args=()) "IREE_HAL_DRIVER_VULKAN", "OFF" if platform.system() == "Darwin" else "ON", ), - get_env_cmake_list("IREE_EXTERNAL_HAL_DRIVERS", ""), + get_env_cmake_list("IREE_EXTERNAL_HAL_DRIVERS", + "" if sysconfig.get_platform() != "linux-x86_64" else "rocm;level_zero"), get_env_cmake_option("IREE_ENABLE_CPUINFO", "ON"), ] + list(extra_cmake_args) add_env_cmake_setting(cmake_args, "IREE_TRACING_PROVIDER") diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c index 565da3ae6d53..deab871ba52a 100644 --- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c +++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c @@ -81,6 +81,14 @@ static iree_hal_cuda_device_t* iree_hal_cuda_device_cast_unsafe( return (iree_hal_cuda_device_t*)base_value; } +iree_status_t iree_cuda_set_current_thread(iree_hal_device_t* device){ + iree_hal_cuda_device_t* cuda_device = iree_hal_cuda_device_cast(device); + CUDA_RETURN_IF_ERROR(cuda_device->context_wrapper.syms, + cuCtxSetCurrent(cuda_device->context_wrapper.cu_context), + "cuCtxSetCurrent"); + return iree_ok_status(); +} + IREE_API_EXPORT void iree_hal_cuda_device_params_initialize( iree_hal_cuda_device_params_t* out_params) { memset(out_params, 0, sizeof(*out_params)); diff --git 
a/runtime/src/iree/hal/drivers/cuda/cuda_device.h b/runtime/src/iree/hal/drivers/cuda/cuda_device.h
index 0cc08870c6a0..2af0c77c68bd 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_device.h
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.h
@@ -41,6 +41,8 @@ CUcontext iree_hal_cuda_device_context(iree_hal_device_t* device);
 iree_hal_cuda_dynamic_symbols_t* iree_hal_cuda_device_dynamic_symbols(
     iree_hal_device_t* device);
 
+iree_status_t iree_cuda_set_current_thread(iree_hal_device_t* device);
+
 #ifdef __cplusplus
 } // extern "C"
 #endif // __cplusplus
diff --git a/samples/distributed/example.py b/samples/distributed/example.py
new file mode 100644
index 000000000000..ff0989403df9
--- /dev/null
+++ b/samples/distributed/example.py
@@ -0,0 +1,55 @@
+from iree.runtime.distributed import run_ranks
+import iree.compiler
+import tempfile
+import numpy as np
+import os
+
+"""
+Example of distributed execution across 2 devices of a small model
+with just an all-reduce operation.
+all_reduce([1, 2, 3, 4], [5, 6, 7, 8]) -> [6, 8, 10, 12].
+
+Dependencies at:
+runtime/bindings/python/iree/runtime/distributed/setup.sh
+"""
+mlir = """
+  func.func @all_reduce_sum(%input : tensor<4xf32>) -> tensor<4xf32> {
+    %out = "stablehlo.all_reduce"(%input) ({
+      ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+        %sum = stablehlo.add %arg0, %arg1 : tensor<f32>
+        stablehlo.return %sum : tensor<f32>
+      }) {channel_handle = #stablehlo.channel_handle,
+          replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
+          use_global_device_ids} : (tensor<4xf32>) -> tensor<4xf32>
+    return %out : tensor<4xf32>
+  }
+"""
+
+inputs = [
+    [np.array([1, 2, 3, 4], dtype=np.float32)],
+    [np.array([5, 6, 7, 8], dtype=np.float32)],
+]
+
+for rank in range(len(inputs)):
+    print(f"Rank {rank} argument = {inputs[rank]}")
+
+with tempfile.TemporaryDirectory() as tmp_dir:
+    module_filepath = os.path.join(tmp_dir, "module.vmfb")
+    iree.compiler.tools.compile_str(
+        input_str=mlir,
+        output_file=module_filepath,
+        target_backends=["cuda"],
+        input_type="stablehlo",
+    )
+
+    num_ranks = len(inputs)
+    # Ranks on the 0th axis.
+    outputs = run_ranks(
+        num_ranks=num_ranks,
+        function="all_reduce_sum",
+        driver="cuda",
+        module_filepath=module_filepath,
+        inputs=inputs,
+    )
+    for rank in range(num_ranks):
+        print(f"Rank {rank} result = {outputs[rank]}")
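One possible way to exercise the sample above, assuming the MPI/NCCL/CUDA dependencies installed by runtime/bindings/python/iree/runtime/distributed/setup.sh and at least one visible CUDA device (ranks are assigned to devices round-robin by run_rank.py), is:

    python samples/distributed/example.py

run_ranks launches one process per rank through mpirun internally, so no explicit mpirun invocation is needed; each rank should print its argument and the all-reduced result [6, 8, 10, 12] stated in the sample's comment.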