diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 000000000000..0fe0ceed72d3 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,11 @@ +# This file contains the list of commits to exclude from 'git blame'. +# Such commits do not meaningfully contribute to git history, and include +# large-scale mechanical changes like code formatting style changes. +# +# To set this file as the default ignore file for 'git blame', run: +# ```shell +# git config blame.ignoreRevsFile .git-blame-ignore-revs +# ``` + +# Refresh clang-format +494089d53db4c183b3ba12e36f61ce1c7553984c diff --git a/.github/actions/setup-build/action.yml b/.github/actions/setup-build/action.yml index 73592a7dce86..a21c9a1d7296 100644 --- a/.github/actions/setup-build/action.yml +++ b/.github/actions/setup-build/action.yml @@ -1,3 +1,4 @@ +# yamllint disable rule:line-length name: "Setup build environment" description: "Setup the build environment. An action so that it can be shared between in-tree/out-of-tree jobs" @@ -24,59 +25,59 @@ runs: using: "composite" steps: - - name: Set up Python - if: ${{ runner.arch == 'X64' }} - uses: actions/setup-python@v4 - with: - python-version: '3.11' + - name: Set up Python + if: ${{ runner.arch == 'X64' }} + uses: actions/setup-python@v4 + with: + python-version: '3.11' - - name: Install MLIR Python depends - if: ${{ runner.os != 'Linux' }} - run: | - python -m pip install -r $GITHUB_WORKSPACE/externals/llvm-project/mlir/python/requirements.txt - shell: bash + - name: Install MLIR Python depends + if: ${{ runner.os != 'Linux' }} + run: | + python -m pip install -r $GITHUB_WORKSPACE/externals/llvm-project/mlir/python/requirements.txt + shell: bash - - name: Install PyTorch nightly depends - if: ${{ runner.os != 'Linux' }} - run: | - python -m pip install -r pytorch-requirements.txt - python -m pip install -r build-requirements.txt - shell: bash + - name: Install PyTorch nightly depends + if: ${{ runner.os != 'Linux' }} + run: | + python -m pip install -r pytorch-requirements.txt + python -m pip install -r build-requirements.txt + shell: bash - - name: Install prerequisites (Linux) - if: ${{ runner.os == 'Linux' }} - run: sudo apt-get install --yes ccache ninja-build - shell: bash + - name: Install prerequisites (Linux) + if: ${{ runner.os == 'Linux' }} + run: sudo apt-get install --yes ccache ninja-build + shell: bash - - name: Install prerequisites (macOS) - if: ${{ runner.os == 'macOS' }} - run: brew install ccache ninja - shell: bash + - name: Install prerequisites (macOS) + if: ${{ runner.os == 'macOS' }} + run: brew install ccache ninja + shell: bash - - name: Install prerequisites (Windows) - if: ${{ runner.os == 'Windows' }} - run: | - pip install ninja - choco install ccache --yes - shell: bash + - name: Install prerequisites (Windows) + if: ${{ runner.os == 'Windows' }} + run: | + pip install ninja + choco install ccache --yes + shell: bash - - name: Configure ccache - if: ${{ inputs.cache-enabled == 'true' }} - run: | - rm -rf ${{ github.workspace }}/.ccache - mkdir -p ${{ github.workspace }}/.ccache - ccache --set-config "cache_dir=${{ github.workspace }}/.ccache" - ccache --set-config "compression=true" - ccache --set-config "max_size=300M" - ccache --zero-stats - shell: bash + - name: Configure ccache + if: ${{ inputs.cache-enabled == 'true' }} + run: | + rm -rf ${{ github.workspace }}/.ccache + mkdir -p ${{ github.workspace }}/.ccache + ccache --set-config "cache_dir=${{ github.workspace }}/.ccache" + ccache --set-config "compression=true" + ccache --set-config "max_size=300M" + ccache --zero-stats + shell: bash - - name: Enable ccache - if: ${{ inputs.cache-enabled == 'true' }} - uses: actions/cache@v3 - with: - path: ${{ github.workspace }}/.ccache - key: ${{ runner.os }}-${{ inputs.cache-suffix }}-${{ github.sha }} - restore-keys: | - ${{ runner.os }}-${{ inputs.cache-suffix }}- - ${{ runner.os }}- + - name: Enable ccache + if: ${{ inputs.cache-enabled == 'true' }} + uses: actions/cache@v3 + with: + path: ${{ github.workspace }}/.ccache + key: ${{ runner.os }}-${{ inputs.cache-suffix }}-${{ github.sha }} + restore-keys: | + ${{ runner.os }}-${{ inputs.cache-suffix }}- + ${{ runner.os }}- diff --git a/.github/workflows/RollPyTorch.yml b/.github/workflows/RollPyTorch.yml index 5c8d74ee0941..1c0f8f568728 100644 --- a/.github/workflows/RollPyTorch.yml +++ b/.github/workflows/RollPyTorch.yml @@ -1,3 +1,4 @@ +# yamllint disable rule:line-length name: Roll PyTorch on: @@ -8,139 +9,138 @@ on: jobs: build_linux: name: Manylinux Build - runs-on: a100 + runs-on: torch-mlir-cpubuilder-manylinux-x86-64 # Don't run this in everyone's forks. if: github.repository == 'llvm/torch-mlir' steps: - - - name: Prepare workspace - run: | - # Clear the workspace directory so that we don't run into errors about - # existing lock files. - sudo rm -rf $GITHUB_WORKSPACE/* - - - name: Get torch-mlir - uses: actions/checkout@v3 - with: - submodules: 'false' - token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - - - name: Get LLVM and StableHlo submodules - run: | - set -eo pipefail - cd ${GITHUB_WORKSPACE} - - # Fetching the submodules concurrently may cause problems, so we fetch - # them one after another. - rm -f .git/modules/externals/llvm-project/index.lock - rm -f .git/modules/externals/stablehlo/index.lock - git submodule update --init --recursive externals/llvm-project - git submodule update --init --recursive externals/stablehlo - - - name: Setup ccache - uses: ./.github/actions/setup-build - with: - cache-suffix: 'rollPyTorch' - - - name: Determine nightly PyTorch version - run: | - set -eo pipefail - - cd ${GITHUB_WORKSPACE} - python -m pip install wheel - sudo apt-get install unzip - - # Fetch the most recent nightly torchvision release - VISION_RELEASE=$(python -m pip index versions -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre torchvision | grep "Available versions" | tr ' ' '\n' | grep "^[0-9]" | sort --version-sort --reverse | head -n1 | tr -d ',' | sed 's/\([^+]*\).*/\1/') - echo "Found torchvision release ${VISION_RELEASE}" - - # Fetch the whl file associated with the nightly torchvision release - rm -f torch*.whl - python -m pip download -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre "torchvision==${VISION_RELEASE}" - - # Downloading the torchvision WHL also downloads the PyTorch WHL file - # Read the version from the downloaded whl file without extracting it - PT_RELEASE=$(unzip -p torch-*.whl 'torch-*/METADATA' | grep "^Version:" | awk '{ print $2 }' | sed 's/\([^+]*\).*/\1/') - echo "Found torch release ${PT_RELEASE}" - printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorch==%s\n" "${PT_RELEASE}" > pytorch-requirements.txt - printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorchvision==%s\n" "${VISION_RELEASE}" > torchvision-requirements.txt - - # Read the commit hash from the downloaded whl file without extracting it - PT_HASH=$(unzip -p torch-"${PT_RELEASE}"*.whl torch/version.py | grep git_version | awk '{ print $3 }' | tr -d "'") - echo "Found torch commit hash ${PT_HASH}" - - PT_HASH_CHANGED=0 - echo "${PT_HASH}" | cmp - pytorch-hash.txt --quiet || PT_HASH_CHANGED=$? - echo "${PT_HASH}" > pytorch-hash.txt - rm torch-"${PT_RELEASE}"*.whl - - # Write the release and hash to the environment file so that we can - # retrieve them when creating a PR - echo "PT_HASH=${PT_HASH}" >> ${GITHUB_ENV} - echo "PT_RELEASE=${PT_RELEASE}" >> ${GITHUB_ENV} - echo "PTVISION_RELEASE=${VISION_RELEASE}" >> ${GITHUB_ENV} - echo "PT_HASH_CHANGED=${PT_HASH_CHANGED}" >> ${GITHUB_ENV} - - - name: Build and test (out-of-tree), also update ODS and abstract interpretation library - if: env.PT_HASH_CHANGED != '0' - run: | - cd ${GITHUB_WORKSPACE} - TM_PACKAGES="out-of-tree" TM_USE_PYTORCH_BINARY="OFF" \ - TORCH_MLIR_SRC_PYTORCH_BRANCH="${{ env.PT_HASH }}" \ - TORCH_MLIR_SRC_PYTORCH_RELEASE="${{ env.PT_RELEASE }}" \ - TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB="ON" \ - ./build_tools/python_deploy/build_linux_packages.sh - - - name: Post issue comment on build failure - if: failure() - uses: peter-evans/create-or-update-comment@v2 - with: - issue-number: 1690 - body: | - The RollPyTorch action has failed. See [CI log](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}) for details. - - The following information may come handy when fixing the code. - ``` - torch version: ${{ env.PT_RELEASE }} - torch commit hash: ${{ env.PT_HASH }} - torchvision version: ${{ env.PTVISION_RELEASE }} - ``` - - - name: Update PyTorch Build Cache (if running on main branch) - if: github.ref_name == 'main' - id: cache-pytorch - uses: actions/cache@v3 - with: - path: ${{ github.workspace }}/build_tools/python_deploy/wheelhouse - key: ${{ runner.os }}-pytorch-${{ env.PT_HASH }} - - - name: Commit changes locally - if: env.PT_HASH_CHANGED != '0' - run: | - cd ${GITHUB_WORKSPACE} - git config user.email "torch-mlir@users.noreply.github.com" - git config user.name "Roll PyTorch Action" - git fetch --recurse-submodules=no - git checkout main - git pull origin main - - - name: Create pull request - uses: peter-evans/create-pull-request@v5.0.1 - with: - author: Roll PyTorch Action - branch: rollpytorch - body: | - torch version: ${{ env.PT_RELEASE }} - torch commit hash: ${{ env.PT_HASH }} - torchvision version: ${{ env.PTVISION_RELEASE }} - commit-message: | - update PyTorch version to ${{ env.PT_RELEASE }} - - - torch version: ${{ env.PT_RELEASE }} - - torch commit hash: ${{ env.PT_HASH }} - - torchvision version: ${{ env.PTVISION_RELEASE }} - committer: Roll PyTorch Action - title: update PyTorch version to ${{ env.PT_RELEASE }} - token: ${{ secrets.ROLLPYTORCH_TOKEN0 }} + - name: Prepare workspace + run: | + # Clear the workspace directory so that we don't run into errors about + # existing lock files. + sudo rm -rf $GITHUB_WORKSPACE/* + + - name: Get torch-mlir + uses: actions/checkout@v3 + with: + submodules: 'false' + token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + + - name: Get LLVM and StableHlo submodules + run: | + set -eo pipefail + cd ${GITHUB_WORKSPACE} + + # Fetching the submodules concurrently may cause problems, so we fetch + # them one after another. + rm -f .git/modules/externals/llvm-project/index.lock + rm -f .git/modules/externals/stablehlo/index.lock + git submodule update --init --recursive externals/llvm-project + git submodule update --init --recursive externals/stablehlo + + - name: Setup ccache + uses: ./.github/actions/setup-build + with: + cache-suffix: 'rollPyTorch' + + - name: Determine nightly PyTorch version + run: | + set -eo pipefail + + cd ${GITHUB_WORKSPACE} + python -m pip install wheel + sudo apt-get install unzip + + # Fetch the most recent nightly torchvision release + VISION_RELEASE=$(python -m pip index versions -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre torchvision | grep "Available versions" | tr ' ' '\n' | grep "^[0-9]" | sort --version-sort --reverse | head -n1 | tr -d ',' | sed 's/\([^+]*\).*/\1/') + echo "Found torchvision release ${VISION_RELEASE}" + + # Fetch the whl file associated with the nightly torchvision release + rm -f torch*.whl + python -m pip download -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre "torchvision==${VISION_RELEASE}" + + # Downloading the torchvision WHL also downloads the PyTorch WHL file + # Read the version from the downloaded whl file without extracting it + PT_RELEASE=$(unzip -p torch-*.whl 'torch-*/METADATA' | grep "^Version:" | awk '{ print $2 }' | sed 's/\([^+]*\).*/\1/') + echo "Found torch release ${PT_RELEASE}" + printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorch==%s\n" "${PT_RELEASE}" > pytorch-requirements.txt + printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorchvision==%s\n" "${VISION_RELEASE}" > torchvision-requirements.txt + + # Read the commit hash from the downloaded whl file without extracting it + PT_HASH=$(unzip -p torch-"${PT_RELEASE}"*.whl torch/version.py | grep git_version | tail -1 | awk '{ print $3 }' | tr -d "'") + echo "Found torch commit hash ${PT_HASH}" + + PT_HASH_CHANGED=0 + echo "${PT_HASH}" | cmp - pytorch-hash.txt --quiet || PT_HASH_CHANGED=$? + echo "${PT_HASH}" > pytorch-hash.txt + rm torch-"${PT_RELEASE}"*.whl + + # Write the release and hash to the environment file so that we can + # retrieve them when creating a PR + echo "PT_HASH=${PT_HASH}" >> ${GITHUB_ENV} + echo "PT_RELEASE=${PT_RELEASE}" >> ${GITHUB_ENV} + echo "PTVISION_RELEASE=${VISION_RELEASE}" >> ${GITHUB_ENV} + echo "PT_HASH_CHANGED=${PT_HASH_CHANGED}" >> ${GITHUB_ENV} + + - name: Build and test (out-of-tree), also update ODS and abstract interpretation library + if: env.PT_HASH_CHANGED != '0' + run: | + cd ${GITHUB_WORKSPACE} + TM_PACKAGES="out-of-tree" TM_USE_PYTORCH_BINARY="OFF" \ + TORCH_MLIR_SRC_PYTORCH_BRANCH="${{ env.PT_HASH }}" \ + TORCH_MLIR_SRC_PYTORCH_RELEASE="${{ env.PT_RELEASE }}" \ + TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB="ON" \ + ./build_tools/python_deploy/build_linux_packages.sh + + - name: Post issue comment on build failure + if: failure() + uses: peter-evans/create-or-update-comment@v2 + with: + issue-number: 1690 + body: | + The RollPyTorch action has failed. See [CI log](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}) for details. + + The following information may come handy when fixing the code. + ``` + torch version: ${{ env.PT_RELEASE }} + torch commit hash: ${{ env.PT_HASH }} + torchvision version: ${{ env.PTVISION_RELEASE }} + ``` + + - name: Update PyTorch Build Cache (if running on main branch) + if: github.ref_name == 'main' + id: cache-pytorch + uses: actions/cache@v3 + with: + path: ${{ github.workspace }}/build_tools/python_deploy/wheelhouse + key: ${{ runner.os }}-pytorch-${{ env.PT_HASH }} + + - name: Commit changes locally + if: env.PT_HASH_CHANGED != '0' + run: | + cd ${GITHUB_WORKSPACE} + git config user.email "torch-mlir@users.noreply.github.com" + git config user.name "Roll PyTorch Action" + git fetch --recurse-submodules=no + git checkout main + git pull origin main + + - name: Create pull request + uses: peter-evans/create-pull-request@v5.0.1 + with: + author: Roll PyTorch Action + branch: rollpytorch + body: | + torch version: ${{ env.PT_RELEASE }} + torch commit hash: ${{ env.PT_HASH }} + torchvision version: ${{ env.PTVISION_RELEASE }} + commit-message: | + update PyTorch version to ${{ env.PT_RELEASE }} + + - torch version: ${{ env.PT_RELEASE }} + - torch commit hash: ${{ env.PT_HASH }} + - torchvision version: ${{ env.PTVISION_RELEASE }} + committer: Roll PyTorch Action + title: update PyTorch version to ${{ env.PT_RELEASE }} + token: ${{ secrets.ROLLPYTORCH_TOKEN0 }} diff --git a/.github/workflows/bazelBuildAndTest.yml b/.github/workflows/bazelBuildAndTest.yml index b3cb3b8fb165..23f2addbe5af 100644 --- a/.github/workflows/bazelBuildAndTest.yml +++ b/.github/workflows/bazelBuildAndTest.yml @@ -1,8 +1,9 @@ +# yamllint disable rule:line-length name: Bazel Build and Test on: push: - branches: [ main ] + branches: [main] workflow_dispatch: # Ensure that only a single job or workflow using the same @@ -24,90 +25,90 @@ jobs: runs-on: ubuntu-latest steps: - - name: Prepare workspace - run: | - # Clear the workspace directory so that we don't run into errors about - # existing lock files. - sudo rm -rf $GITHUB_WORKSPACE/* + - name: Prepare workspace + run: | + # Clear the workspace directory so that we don't run into errors about + # existing lock files. + sudo rm -rf $GITHUB_WORKSPACE/* - - name: Checkout torch-mlir - uses: actions/checkout@v3 - with: - submodules: 'true' + - name: Checkout torch-mlir + uses: actions/checkout@v3 + with: + submodules: 'true' - # Continually update cache even if there's a "hit" during - # restore to avoid the cache going stale over time - # https://github.com/actions/cache/blob/main/workarounds.md#update-a-cache - - name: Setup cache for bazel - uses: actions/cache@v3 - with: - path: ~/.cache/bazel - key: torch_mlir-bazel-build-cache-${{ runner.os }}-${{ github.sha }} - restore-keys: | - torch_mlir-bazel-build-cache-${{ runner.os }} + # Continually update cache even if there's a "hit" during + # restore to avoid the cache going stale over time + # https://github.com/actions/cache/blob/main/workarounds.md#update-a-cache + - name: Setup cache for bazel + uses: actions/cache@v3 + with: + path: ~/.cache/bazel + key: torch_mlir-bazel-build-cache-${{ runner.os }}-${{ github.sha }} + restore-keys: | + torch_mlir-bazel-build-cache-${{ runner.os }} - # Change bazel cache directory to root ownership - # to allow writing to it from within the docker container. - # If no cache hits, this directory is not present - # so don't run chown (will error otherwise). - - name: Set bazel cache permissions - run: | - if [ -d "${HOME}/.cache/bazel" ]; then - sudo chown -R root:root "${HOME}/.cache/bazel" - fi + # Change bazel cache directory to root ownership + # to allow writing to it from within the docker container. + # If no cache hits, this directory is not present + # so don't run chown (will error otherwise). + - name: Set bazel cache permissions + run: | + if [ -d "${HOME}/.cache/bazel" ]; then + sudo chown -R root:root "${HOME}/.cache/bazel" + fi - - name: Build docker image - run: | - docker build -f utils/bazel/docker/Dockerfile \ - -t torch-mlir:ci \ - . + - name: Build docker image + run: | + docker build -f utils/bazel/docker/Dockerfile \ + -t torch-mlir:ci \ + . - - name: Verify buildifier was run (bazel lint) - run: | - docker run --rm \ - -v "$(pwd)":"/opt/src/torch-mlir" \ - -v "${HOME}/.cache/bazel":"/root/.cache/bazel" \ - torch-mlir:ci \ - bazel run @torch-mlir//:buildifier - if [ -n "$(git status --porcelain)" ]; then - echo "Please 'bazel run @torch-mlir//:buildifier' and commit changes." - exit 1 - fi + - name: Verify buildifier was run (bazel lint) + run: | + docker run --rm \ + -v "$(pwd)":"/opt/src/torch-mlir" \ + -v "${HOME}/.cache/bazel":"/root/.cache/bazel" \ + torch-mlir:ci \ + bazel run @torch-mlir//:buildifier + if [ -n "$(git status --porcelain)" ]; then + echo "Please 'bazel run @torch-mlir//:buildifier' and commit changes." + exit 1 + fi - - name: Bazel build torch-mlir - run: | - docker run --rm \ - -v "$(pwd)":"/opt/src/torch-mlir" \ - -v "${HOME}/.cache/bazel":"/root/.cache/bazel" \ - torch-mlir:ci \ - bazel build @torch-mlir//:torch-mlir-opt + - name: Bazel build torch-mlir + run: | + docker run --rm \ + -v "$(pwd)":"/opt/src/torch-mlir" \ + -v "${HOME}/.cache/bazel":"/root/.cache/bazel" \ + torch-mlir:ci \ + bazel build @torch-mlir//:torch-mlir-opt - - name: Bazel test torch-mlir (lit tests) - run: | - docker run --rm \ - -v "$(pwd)":"/opt/src/torch-mlir" \ - -v "${HOME}/.cache/bazel":"/root/.cache/bazel" \ - torch-mlir:ci \ - bazel test @torch-mlir//test/... + - name: Bazel test torch-mlir (lit tests) + run: | + docker run --rm \ + -v "$(pwd)":"/opt/src/torch-mlir" \ + -v "${HOME}/.cache/bazel":"/root/.cache/bazel" \ + torch-mlir:ci \ + bazel test @torch-mlir//test/... - # Switch back bazel cache directory to user ownership - # to allow GHA post-cache step to save cache without - # permissions issue. - - name: Switch bazel cache permissions - run: | - if [ -d "${HOME}/.cache/bazel" ]; then - sudo chown -R "$USER":"$USER" "${HOME}/.cache/bazel" - fi + # Switch back bazel cache directory to user ownership + # to allow GHA post-cache step to save cache without + # permissions issue. + - name: Switch bazel cache permissions + run: | + if [ -d "${HOME}/.cache/bazel" ]; then + sudo chown -R "$USER":"$USER" "${HOME}/.cache/bazel" + fi - - name: Send mail - if: failure() - uses: dawidd6/action-send-mail@v3 - with: - server_address: ${{ secrets.SMTP_SERVER }} - server_port: ${{ secrets.SMTP_PORT }} - username: ${{ secrets.SMTP_USERNAME }} - password: ${{ secrets.SMTP_PASSWORD }} - subject: GitHub Action Bazel Build and Test failed! - body: Bazel Build job failed! See https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }} for more information. - to: ${{ secrets.MAIL_RECEIVER }} - from: Torch-MLIR Bazel Build GitHub Actions + - name: Send mail + if: failure() + uses: dawidd6/action-send-mail@v3 + with: + server_address: ${{ secrets.SMTP_SERVER }} + server_port: ${{ secrets.SMTP_PORT }} + username: ${{ secrets.SMTP_USERNAME }} + password: ${{ secrets.SMTP_PASSWORD }} + subject: GitHub Action Bazel Build and Test failed! + body: Bazel Build job failed! See https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }} for more information. + to: ${{ secrets.MAIL_RECEIVER }} + from: Torch-MLIR Bazel Build GitHub Actions diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml index b6726cf90b52..817ae6d01461 100644 --- a/.github/workflows/buildAndTest.yml +++ b/.github/workflows/buildAndTest.yml @@ -1,9 +1,10 @@ +# yamllint disable rule:line-length name: Build and Test on: - pull_request: - push: - branches: [ main, feature/* ] + #pull_request: + #push: + # branches: [ main, feature/* ] workflow_dispatch: # Ensure that only a single job or workflow using the same @@ -28,8 +29,8 @@ jobs: strategy: fail-fast: true matrix: - os-arch: [ubuntu-x86_64] #, macos-arm64, windows-x86_64] - llvm-build: [in-tree] #, out-of-tree] + os-arch: [macos-arm64, windows-x86_64] + llvm-build: [in-tree, out-of-tree] torch-binary: [ON] torch-version: [nightly, stable] exclude: @@ -56,101 +57,101 @@ jobs: runs-on: ${{ matrix.os }} steps: - - - name: Prepare workspace - if: ${{ matrix.os-arch == 'ubuntu-x86_64' }} - run: | - # Clear the workspace directory so that we don't run into errors about - # existing lock files. - sudo rm -rf $GITHUB_WORKSPACE/* - - - name: Checkout torch-mlir - uses: actions/checkout@v3 - with: - submodules: 'true' - - - name: Fetch PyTorch commit hash - if: ${{ matrix.os-arch != 'windows-x86_64' }} - run: | - PT_HASH="$(cat ${GITHUB_WORKSPACE}/pytorch-hash.txt)" - echo "PT_HASH=${PT_HASH}" >> ${GITHUB_ENV} - - - name: Setup ccache - uses: ./.github/actions/setup-build - with: - cache-suffix: 'build-${{ matrix.llvm-build }}-${{ matrix.torch-version }}' - torch-version: ${{ matrix.torch-version }} - - - name: Set up Visual Studio shell - if: ${{ matrix.os-arch == 'windows-x86_64' }} - uses: egor-tensin/vs-shell@v2 - with: - arch: x64 - - - name: Try to Restore PyTorch Build Cache - if: ${{ matrix.torch-binary == 'OFF' }} - id: cache-pytorch - uses: actions/cache/restore@v3 - with: - path: ${{ github.workspace }}/build_tools/python_deploy/wheelhouse - key: ${{ runner.os }}-pytorch-${{ env.PT_HASH }} - - - name: Build and Test os-arch='ubuntu-x86_64' llvm-build='${{ matrix.llvm-build }}' torch-binary='${{ matrix.torch-binary }}' - if: ${{ matrix.os-arch == 'ubuntu-x86_64' }} - run: | - cd $GITHUB_WORKSPACE - TORCH_MLIR_SRC_PYTORCH_BRANCH="$(cat pytorch-hash.txt)" \ - TM_PACKAGES="${{ matrix.llvm-build }}" \ - TM_USE_PYTORCH_BINARY="${{ matrix.torch-binary }}" \ - TM_PYTORCH_INSTALL_WITHOUT_REBUILD="${{ steps.cache-pytorch.outputs.cache-hit }}" \ - TM_TORCH_VERSION="${{ matrix.torch-version }}" \ - ./build_tools/python_deploy/build_linux_packages.sh - - - name: Configure os-arch='macos-arm64' llvm-build='in-tree' torch-binary='${{ matrix.torch-binary }}' - # cross compile, can't test arm64 - if: ${{ matrix.os-arch == 'macos-arm64' && matrix.llvm-build == 'in-tree' }} - run: | - # TODO: Reenable LTC after build on macOS-arm64 is fixed (https://github.com/llvm/torch-mlir/issues/1253) - cmake -GNinja -Bbuild_arm64 \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_C_COMPILER=clang \ - -DCMAKE_CXX_COMPILER=clang++ \ - -DCMAKE_C_COMPILER_LAUNCHER=ccache \ - -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ - -DCMAKE_LINKER=lld \ - -DCMAKE_OSX_ARCHITECTURES=arm64 \ - -DLLVM_ENABLE_ASSERTIONS=ON \ - -DLLVM_ENABLE_PROJECTS=mlir \ - -DLLVM_EXTERNAL_PROJECTS="torch-mlir" \ - -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$GITHUB_WORKSPACE" \ - -DLLVM_TARGETS_TO_BUILD=AArch64 \ - -DLLVM_USE_HOST_TOOLS=ON \ - -DLLVM_ENABLE_ZSTD=OFF \ - -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ - -DTORCH_MLIR_ENABLE_STABLEHLO=OFF \ - -DTORCH_MLIR_ENABLE_LTC=OFF \ - -DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \ - -DMACOSX_DEPLOYMENT_TARGET=12.0 \ - -DPython3_EXECUTABLE="$(which python)" \ - $GITHUB_WORKSPACE/externals/llvm-project/llvm - - - name: Build torch-mlir (cross-compile) - if: ${{ matrix.os-arch == 'macos-arm64' }} - run: | - cmake --build build_arm64 - - - name: Build (Windows) - if: ${{ matrix.os-arch == 'windows-x86_64' }} - shell: bash - run: ./build_tools/python_deploy/build_windows_ci.sh - - - name: Save PyTorch Build Cache - if: ${{ github.ref_name == 'main' && matrix.torch-binary == 'OFF' }} - uses: actions/cache/save@v3 - with: - path: ${{ github.workspace }}/build_tools/python_deploy/wheelhouse - key: ${{ runner.os }}-pytorch-${{ env.PT_HASH }} - - - name: Print ccache statistics - shell: bash - run: ccache --show-stats + - name: Prepare workspace + if: ${{ matrix.os-arch == 'ubuntu-x86_64' }} + run: | + # Clear the workspace directory so that we don't run into errors about + # existing lock files. + sudo rm -rf $GITHUB_WORKSPACE/* + + - name: Checkout torch-mlir + uses: actions/checkout@v3 + with: + submodules: 'true' + fetch-depth: 0 + + - name: Fetch PyTorch commit hash + if: ${{ matrix.os-arch != 'windows-x86_64' }} + run: | + PT_HASH="$(cat ${GITHUB_WORKSPACE}/pytorch-hash.txt)" + echo "PT_HASH=${PT_HASH}" >> ${GITHUB_ENV} + + - name: Setup ccache + uses: ./.github/actions/setup-build + with: + cache-suffix: 'build-${{ matrix.llvm-build }}-${{ matrix.torch-version }}' + torch-version: ${{ matrix.torch-version }} + + - name: Set up Visual Studio shell + if: ${{ matrix.os-arch == 'windows-x86_64' }} + uses: egor-tensin/vs-shell@v2 + with: + arch: x64 + + - name: Try to Restore PyTorch Build Cache + if: ${{ matrix.torch-binary == 'OFF' }} + id: cache-pytorch + uses: actions/cache/restore@v3 + with: + path: ${{ github.workspace }}/build_tools/python_deploy/wheelhouse + key: ${{ runner.os }}-pytorch-${{ env.PT_HASH }} + + - name: Build and Test os-arch='ubuntu-x86_64' llvm-build='${{ matrix.llvm-build }}' torch-binary='${{ matrix.torch-binary }}' + if: ${{ matrix.os-arch == 'ubuntu-x86_64' }} + run: | + cd $GITHUB_WORKSPACE + TORCH_MLIR_SRC_PYTORCH_BRANCH="$(cat pytorch-hash.txt)" \ + TM_PACKAGES="${{ matrix.llvm-build }}" \ + TM_USE_PYTORCH_BINARY="${{ matrix.torch-binary }}" \ + TM_PYTORCH_INSTALL_WITHOUT_REBUILD="${{ steps.cache-pytorch.outputs.cache-hit }}" \ + TM_TORCH_VERSION="${{ matrix.torch-version }}" \ + ./build_tools/python_deploy/build_linux_packages.sh + + - name: Configure os-arch='macos-arm64' llvm-build='in-tree' torch-binary='${{ matrix.torch-binary }}' + # cross compile, can't test arm64 + if: ${{ matrix.os-arch == 'macos-arm64' && matrix.llvm-build == 'in-tree' }} + run: | + # TODO: Reenable LTC after build on macOS-arm64 is fixed (https://github.com/llvm/torch-mlir/issues/1253) + cmake -GNinja -Bbuild_arm64 \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DCMAKE_LINKER=lld \ + -DCMAKE_OSX_ARCHITECTURES=arm64 \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_EXTERNAL_PROJECTS="torch-mlir" \ + -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$GITHUB_WORKSPACE" \ + -DLLVM_TARGETS_TO_BUILD=AArch64 \ + -DLLVM_USE_HOST_TOOLS=ON \ + -DLLVM_ENABLE_ZSTD=OFF \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DTORCH_MLIR_ENABLE_STABLEHLO=OFF \ + -DTORCH_MLIR_ENABLE_LTC=OFF \ + -DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \ + -DMACOSX_DEPLOYMENT_TARGET=12.0 \ + -DPython3_EXECUTABLE="$(which python)" \ + $GITHUB_WORKSPACE/externals/llvm-project/llvm + + - name: Build torch-mlir (cross-compile) + if: ${{ matrix.os-arch == 'macos-arm64' }} + run: | + cmake --build build_arm64 + + - name: Build (Windows) + if: ${{ matrix.os-arch == 'windows-x86_64' }} + shell: bash + run: ./build_tools/python_deploy/build_windows_ci.sh + + - name: Save PyTorch Build Cache + if: ${{ github.ref_name == 'main' && matrix.torch-binary == 'OFF' }} + uses: actions/cache/save@v3 + with: + path: ${{ github.workspace }}/build_tools/python_deploy/wheelhouse + key: ${{ runner.os }}-pytorch-${{ env.PT_HASH }} + + - name: Print ccache statistics + shell: bash + run: ccache --show-stats diff --git a/.github/workflows/buildRelease.yml b/.github/workflows/buildRelease.yml index 278590ef3511..4e4dd0a6fb95 100644 --- a/.github/workflows/buildRelease.yml +++ b/.github/workflows/buildRelease.yml @@ -1,3 +1,4 @@ +# yamllint disable rule:line-length name: Release Build on: @@ -20,73 +21,72 @@ jobs: packages: write strategy: matrix: - package: [ torch-mlir ] - py_version: [ cp38-cp38, cp310-cp310 ] # cp311-cp311 + package: [torch-mlir] + py_version: [cp38-cp38, cp310-cp310] # cp311-cp311 torch-version: [stable] # nightly steps: + - name: Prepare workspace + run: | + # Clear the workspace directory so that we don't run into errors about + # existing lock files. + sudo rm -rf $GITHUB_WORKSPACE/* - - name: Prepare workspace - run: | - # Clear the workspace directory so that we don't run into errors about - # existing lock files. - sudo rm -rf $GITHUB_WORKSPACE/* + - name: Get torch-mlir + uses: actions/checkout@v3 + with: + submodules: 'true' + fetch-depth: 0 - - name: Get torch-mlir - uses: actions/checkout@v3 - with: - submodules: 'true' + - uses: ./.github/actions/setup-build + with: + cache-enabled: 'false' + - name: Build Python wheels and smoke test. + run: | + cd $GITHUB_WORKSPACE + TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} + printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version + TM_SKIP_TESTS=ON \ + TM_PYTHON_VERSIONS=${{ matrix.py_version }} \ + TM_PACKAGES=${{ matrix.package }} \ + TM_TORCH_VERSION="${{ matrix.torch-version }}" \ + ./build_tools/python_deploy/build_linux_packages.sh + # If we were given a release_id, then upload the package we just built + # to the github releases page. + - name: Upload Release Assets (if requested) + if: github.event.inputs.release_id != '' + id: upload-release-assets + uses: dwenegar/upload-release-assets@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + assets_path: ./build_tools/python_deploy/wheelhouse/torch*.whl + # Publishing is necessary to make the release visible to `pip` + # on the github releases page. + - name: Publish Release (if requested) + if: github.event.inputs.release_id != '' + id: publish_release + uses: eregon/publish-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + - name: Create dist directory + if: github.event.inputs.release_id != '' + run: mkdir dist + - name: Copy releases to publish to dist directory + if: github.event.inputs.release_id != '' + run: cp build_tools/python_deploy/wheelhouse/torch_mlir*.whl dist/ - - uses: ./.github/actions/setup-build - with: - cache-enabled: 'false' - - name: Build Python wheels and smoke test. - run: | - cd $GITHUB_WORKSPACE - TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} - printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version - TM_SKIP_TESTS=ON \ - TM_PYTHON_VERSIONS=${{ matrix.py_version }} \ - TM_PACKAGES=${{ matrix.package }} \ - TM_TORCH_VERSION="${{ matrix.torch-version }}" \ - ./build_tools/python_deploy/build_linux_packages.sh - - # If we were given a release_id, then upload the package we just built - # to the github releases page. - - name: Upload Release Assets (if requested) - if: github.event.inputs.release_id != '' - id: upload-release-assets - uses: dwenegar/upload-release-assets@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - release_id: ${{ github.event.inputs.release_id }} - assets_path: ./build_tools/python_deploy/wheelhouse/torch*.whl - # Publishing is necessary to make the release visible to `pip` - # on the github releases page. - - name: Publish Release (if requested) - if: github.event.inputs.release_id != '' - id: publish_release - uses: eregon/publish-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - release_id: ${{ github.event.inputs.release_id }} - - name: Create dist directory - if: github.event.inputs.release_id != '' - run: mkdir dist - - name: Copy releases to publish to dist directory - if: github.event.inputs.release_id != '' - run: cp build_tools/python_deploy/wheelhouse/torch_mlir*.whl dist/ - - # Wheels must be published from a linux environment. - # - # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - - name: Store the binary wheel - uses: actions/upload-artifact@v2 - with: - name: wheels - path: dist + # Wheels must be published from a linux environment. + # + # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 + - name: Store the binary wheel + uses: actions/upload-artifact@v2 + with: + name: wheels + path: dist build_linux_arm64: if: false @@ -94,69 +94,68 @@ jobs: runs-on: linux-arm64 strategy: matrix: - package: [ torch-mlir ] - py_version: [ cp311-cp311 ] + package: [torch-mlir] + py_version: [cp311-cp311] steps: + - name: Prepare workspace + run: | + # Clear the workspace directory so that we don't run into errors about + # existing lock files. + sudo rm -rf $GITHUB_WORKSPACE/* - - name: Prepare workspace - run: | - # Clear the workspace directory so that we don't run into errors about - # existing lock files. - sudo rm -rf $GITHUB_WORKSPACE/* - - - name: Get torch-mlir - uses: actions/checkout@v3 - with: - submodules: 'true' - fetch-depth: 0 + - name: Get torch-mlir + uses: actions/checkout@v3 + with: + submodules: 'true' + fetch-depth: 0 - - uses: ./.github/actions/setup-build - with: - cache-enabled: 'false' - - name: Build Python wheels and smoke test. - run: | - cd $GITHUB_WORKSPACE - TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} - printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version - TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} TORCH_MLIR_ENABLE_LTC='0' ./build_tools/python_deploy/build_linux_packages.sh + - uses: ./.github/actions/setup-build + with: + cache-enabled: 'false' + - name: Build Python wheels and smoke test. + run: | + cd $GITHUB_WORKSPACE + TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} + printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version + TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} TORCH_MLIR_ENABLE_LTC='0' ./build_tools/python_deploy/build_linux_packages.sh - # If we were given a release_id, then upload the package we just built - # to the github releases page. - - name: Upload Release Assets (if requested) - if: github.event.inputs.release_id != '' - id: upload-release-assets - uses: dwenegar/upload-release-assets@v1 - env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - with: - release_id: ${{ github.event.inputs.release_id }} - assets_path: ./build_tools/python_deploy/wheelhouse/torch*.whl - # Publishing is necessary to make the release visible to `pip` - # on the github releases page. - - name: Publish Release (if requested) - if: github.event.inputs.release_id != '' - id: publish_release - uses: eregon/publish-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - with: - release_id: ${{ github.event.inputs.release_id }} - - name: Create dist directory - if: github.event.inputs.release_id != '' - run: mkdir dist - - name: Copy releases to publish to dist directory - if: github.event.inputs.release_id != '' - run: cp build_tools/python_deploy/wheelhouse/torch_mlir*.whl dist/ + # If we were given a release_id, then upload the package we just built + # to the github releases page. + - name: Upload Release Assets (if requested) + if: github.event.inputs.release_id != '' + id: upload-release-assets + uses: dwenegar/upload-release-assets@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + assets_path: ./build_tools/python_deploy/wheelhouse/torch*.whl + # Publishing is necessary to make the release visible to `pip` + # on the github releases page. + - name: Publish Release (if requested) + if: github.event.inputs.release_id != '' + id: publish_release + uses: eregon/publish-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + - name: Create dist directory + if: github.event.inputs.release_id != '' + run: mkdir dist + - name: Copy releases to publish to dist directory + if: github.event.inputs.release_id != '' + run: cp build_tools/python_deploy/wheelhouse/torch_mlir*.whl dist/ - # Wheels must be published from a linux environment. - # - # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - - name: Store the binary wheel - uses: actions/upload-artifact@v2 - with: - name: wheels - path: dist + # Wheels must be published from a linux environment. + # + # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 + - name: Store the binary wheel + uses: actions/upload-artifact@v2 + with: + name: wheels + path: dist build_macos: if: false @@ -164,60 +163,60 @@ jobs: runs-on: macos-latest strategy: matrix: - package: [ torch-mlir ] + package: [torch-mlir] steps: - - name: Get torch-mlir - uses: actions/checkout@v3 - with: - submodules: 'true' - - uses: ./.github/actions/setup-build - with: - cache-enabled: 'false' - - name: Build Python wheels and smoke test. - run: | - cd $GITHUB_WORKSPACE - python -m pip install wheel - TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} - printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version - sudo ./build_tools/python_deploy/install_macos_deps.sh - packages=${{ matrix.package }} TORCH_MLIR_PYTHON_VERSIONS="3.11" ./build_tools/python_deploy/build_macos_packages.sh + - name: Get torch-mlir + uses: actions/checkout@v3 + with: + submodules: 'true' + - uses: ./.github/actions/setup-build + with: + cache-enabled: 'false' + - name: Build Python wheels and smoke test. + run: | + cd $GITHUB_WORKSPACE + python -m pip install wheel + TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} + printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version + sudo ./build_tools/python_deploy/install_macos_deps.sh + packages=${{ matrix.package }} TORCH_MLIR_PYTHON_VERSIONS="3.11" ./build_tools/python_deploy/build_macos_packages.sh - # If we were given a release_id, then upload the package we just built - # to the github releases page. - - name: Upload Release Assets (if requested) - if: github.event.inputs.release_id != '' - id: upload-release-assets - uses: dwenegar/upload-release-assets@v1 - env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - with: - release_id: ${{ github.event.inputs.release_id }} - assets_path: ./build_tools/python_deploy/wheelhouse/torch*.whl - # Publishing is necessary to make the release visible to `pip` - # on the github releases page. - - name: Publish Release (if requested) - if: github.event.inputs.release_id != '' - id: publish_release - uses: eregon/publish-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - with: - release_id: ${{ github.event.inputs.release_id }} - - name: Create dist directory - if: github.event.inputs.release_id != '' - run: mkdir dist - - name: Copy releases to publish to dist directory - if: github.event.inputs.release_id != '' - run: cp build_tools/python_deploy/wheelhouse/torch_mlir*.whl dist/ + # If we were given a release_id, then upload the package we just built + # to the github releases page. + - name: Upload Release Assets (if requested) + if: github.event.inputs.release_id != '' + id: upload-release-assets + uses: dwenegar/upload-release-assets@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + assets_path: ./build_tools/python_deploy/wheelhouse/torch*.whl + # Publishing is necessary to make the release visible to `pip` + # on the github releases page. + - name: Publish Release (if requested) + if: github.event.inputs.release_id != '' + id: publish_release + uses: eregon/publish-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + - name: Create dist directory + if: github.event.inputs.release_id != '' + run: mkdir dist + - name: Copy releases to publish to dist directory + if: github.event.inputs.release_id != '' + run: cp build_tools/python_deploy/wheelhouse/torch_mlir*.whl dist/ - # Wheels must be published from a linux environment. - # - # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - - name: Store the binary wheel - uses: actions/upload-artifact@v2 - with: - name: wheels - path: dist + # Wheels must be published from a linux environment. + # + # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 + - name: Store the binary wheel + uses: actions/upload-artifact@v2 + with: + name: wheels + path: dist build_windows: if: false @@ -225,64 +224,64 @@ jobs: runs-on: windows-latest strategy: matrix: - package: [ torch-mlir ] + package: [torch-mlir] steps: - - name: Get torch-mlir - uses: actions/checkout@v3 - with: - submodules: 'true' - - uses: ./.github/actions/setup-build - with: - cache-enabled: 'false' - - name: Set up Visual Studio shell - uses: egor-tensin/vs-shell@v2 - with: - arch: x64 - - name: Build Python wheels and smoke test. - shell: pwsh - run: | - $env:TORCH_MLIR_ENABLE_JIT_IR_IMPORTER='1' - $env:TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS='0' - $env:TORCH_MLIR_PYTHON_PACKAGE_VERSION = '${{ github.event.inputs.python_package_version }}' - ./build_tools/python_deploy/build_windows.ps1 + - name: Get torch-mlir + uses: actions/checkout@v3 + with: + submodules: 'true' + - uses: ./.github/actions/setup-build + with: + cache-enabled: 'false' + - name: Set up Visual Studio shell + uses: egor-tensin/vs-shell@v2 + with: + arch: x64 + - name: Build Python wheels and smoke test. + shell: pwsh + run: | + $env:TORCH_MLIR_ENABLE_JIT_IR_IMPORTER='1' + $env:TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS='0' + $env:TORCH_MLIR_PYTHON_PACKAGE_VERSION = '${{ github.event.inputs.python_package_version }}' + ./build_tools/python_deploy/build_windows.ps1 - # If we were given a release_id, then upload the package we just built - # to the github releases page. - - name: Upload Release Assets (if requested) - if: github.event.inputs.release_id != '' - id: upload-release-assets - uses: dwenegar/upload-release-assets@v1 - env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - with: - release_id: ${{ github.event.inputs.release_id }} - assets_path: ./wheelhouse/torch*.whl - # Publishing is necessary to make the release visible to `pip` - # on the github releases page. - - name: Publish Release (if requested) - if: github.event.inputs.release_id != '' - id: publish_release - uses: eregon/publish-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - with: - release_id: ${{ github.event.inputs.release_id }} - - name: Create dist directory - if: github.event.inputs.release_id != '' - run: mkdir dist - continue-on-error: true - - name: Copy releases to publish to dist directory - if: github.event.inputs.release_id != '' - run: cp ./wheelhouse/torch_mlir*.whl dist/ + # If we were given a release_id, then upload the package we just built + # to the github releases page. + - name: Upload Release Assets (if requested) + if: github.event.inputs.release_id != '' + id: upload-release-assets + uses: dwenegar/upload-release-assets@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + assets_path: ./wheelhouse/torch*.whl + # Publishing is necessary to make the release visible to `pip` + # on the github releases page. + - name: Publish Release (if requested) + if: github.event.inputs.release_id != '' + id: publish_release + uses: eregon/publish-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + - name: Create dist directory + if: github.event.inputs.release_id != '' + run: mkdir dist + continue-on-error: true + - name: Copy releases to publish to dist directory + if: github.event.inputs.release_id != '' + run: cp ./wheelhouse/torch_mlir*.whl dist/ - # Wheels must be published from a linux environment. - # - # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - - name: Store the binary wheel - uses: actions/upload-artifact@v2 - with: - name: wheels - path: dist + # Wheels must be published from a linux environment. + # + # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 + - name: Store the binary wheel + uses: actions/upload-artifact@v2 + with: + name: wheels + path: dist publish_releases: runs-on: ubuntu-latest @@ -291,35 +290,35 @@ jobs: actions: write packages: write needs: - - build_linux - #- build_linux_arm64 - #- build_macos - #- build_windows + - build_linux + #- build_linux_arm64 + #- build_macos + #- build_windows # Publish even if one of the builds failed if: ${{ always() }} steps: - - name: Invoke Publish Releases Page - uses: benc-uk/workflow-dispatch@v1 - with: - workflow: Publish releases page - token: ${{ secrets.GITHUB_TOKEN }} + - name: Invoke Publish Releases Page + uses: benc-uk/workflow-dispatch@v1 + with: + workflow: Publish releases page + token: ${{ secrets.GITHUB_TOKEN }} - # Wheels must be published from a linux environment. - # - # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - # - # We're temporarily disabling pypi publishing until we can fix audit wheel - # ODR torch issues. See https://github.com/llvm/torch-mlir/issues/1709 - # - #- name: Download wheels for publishing to PyPI - # uses: actions/download-artifact@v3 - # with: - # name: wheels - # path: dist - #- name: Publish to PyPI - # if: github.event.inputs.release_id != '' - # uses: pypa/gh-action-pypi-publish@v1.5.1 - # with: - # password: ${{ secrets.PYPI_API_TOKEN }} + # Wheels must be published from a linux environment. + # + # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 + # + # We're temporarily disabling pypi publishing until we can fix audit wheel + # ODR torch issues. See https://github.com/llvm/torch-mlir/issues/1709 + # + #- name: Download wheels for publishing to PyPI + # uses: actions/download-artifact@v3 + # with: + # name: wheels + # path: dist + #- name: Publish to PyPI + # if: github.event.inputs.release_id != '' + # uses: pypa/gh-action-pypi-publish@v1.5.1 + # with: + # password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 000000000000..350488ee5195 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,82 @@ +# yamllint disable rule:line-length +name: CI + +on: + workflow_dispatch: + workflow_call: + pull_request: + branches: [main, feature/*] + push: + branches: [main, feature/*] + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). + group: ci-build-test-cpp-linux-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + build-test-linux: + strategy: + fail-fast: true + matrix: + torch-version: [nightly, stable] + name: Build and Test (Linux, torch-${{ matrix.torch-version }}, assertions) + runs-on: ubuntu-latest + env: + CACHE_DIR: ${{ github.workspace }}/.container-cache + steps: + - name: Configure local git mirrors + run: | + # Our stock runners have access to certain local git caches. If these + # files are available, it will prime the cache and configure git to + # use them. Practically, this eliminates network/latency for cloning + # llvm. + if [[ -x /gitmirror/scripts/trigger_update_mirrors.sh ]]; then + /gitmirror/scripts/trigger_update_mirrors.sh + /gitmirror/scripts/git_config.sh + fi + - name: "Checking out repository" + uses: actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3 # v3.5.0 + with: + submodules: true + + - name: Runner setup + run: | + sudo apt-get update + sudo apt-get install -y ccache clang + + - name: Enable cache + uses: actions/cache/restore@v3 + with: + path: ${{ env.CACHE_DIR }} + key: build-test-cpp-asserts-manylinux-${{ matrix.torch-version }}-v2-${{ github.sha }} + restore-keys: | + build-test-cpp-asserts-manylinux-${{ matrix.torch-version }}-v2- + + - name: Install python deps (torch-${{ matrix.torch-version }}) + run: | + export cache_dir="${{ env.CACHE_DIR }}" + bash build_tools/ci/install_python_deps.sh ${{ matrix.torch-version }} + + - name: Build project + run: | + export cache_dir="${{ env.CACHE_DIR }}" + bash build_tools/ci/build_posix.sh + + - name: Save cache + uses: actions/cache/save@v3 + if: ${{ !cancelled() }} + with: + path: ${{ env.CACHE_DIR }} + key: build-test-cpp-asserts-manylinux-${{ matrix.torch-version }}-v2-${{ github.sha }} + + - name: Integration tests (torch-${{ matrix.torch-version }}) + run: | + bash build_tools/ci/test_posix.sh ${{ matrix.torch-version }} + + - name: Check generated sources (torch-nightly only) + if: ${{ matrix.torch-version == 'nightly' }} + run: | + bash build_tools/ci/check_generated_sources.sh diff --git a/.github/workflows/gh-pages-releases.yml b/.github/workflows/gh-pages-releases.yml index 5ee7047c5d8d..1e14b8feece8 100644 --- a/.github/workflows/gh-pages-releases.yml +++ b/.github/workflows/gh-pages-releases.yml @@ -1,3 +1,4 @@ +# yamllint disable rule:line-length # See: https://github.com/llvm/torch-mlir/issues/1374 name: Publish releases page diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 000000000000..364e9fa9d378 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,26 @@ +# yamllint disable rule:line-length +name: Lint Checks + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + static_lint_checks: + name: Static Lint Checks + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + # `git-clang-format` needs access to the commit before the + # current merge commit to know what changes to format. + fetch-depth: 2 + - name: Validate GitHub Actions yaml files + run: | + yamllint ./.github/workflows/ ./.github/actions/ + - name: Check clang-format + run: | + wget -q https://raw.githubusercontent.com/llvm/llvm-project/main/clang/tools/clang-format/git-clang-format + python3 git-clang-format --diff HEAD~1 diff --git a/.github/workflows/merge-rollpytorch.yml b/.github/workflows/merge-rollpytorch.yml index 7247a3683281..58a91fd1d409 100644 --- a/.github/workflows/merge-rollpytorch.yml +++ b/.github/workflows/merge-rollpytorch.yml @@ -1,3 +1,4 @@ +# yamllint disable rule:line-length name: RollPyTorch Merge on: @@ -15,19 +16,19 @@ jobs: github.event.workflow_run.conclusion == 'success' steps: - # Fetch the repo first so that the gh command knows where to look for the PR - - name: Fetch Repo - uses: actions/checkout@v3 - with: - token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + # Fetch the repo first so that the gh command knows where to look for the PR + - name: Fetch Repo + uses: actions/checkout@v3 + with: + token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - - name: Merge RollPyTorch PR - run: | - for pr_id in ${{ join(github.event.workflow_run.pull_requests.*.number, ' ') }} - do - echo "Merging PR: $pr_id" - gh pr merge $pr_id --delete-branch --squash - done - shell: bash - env: - GH_TOKEN: ${{ secrets.ROLLPYTORCH_TOKEN1 }} + - name: Merge RollPyTorch PR + run: | + for pr_id in ${{ join(github.event.workflow_run.pull_requests.*.number, ' ') }} + do + echo "Merging PR: $pr_id" + gh pr merge $pr_id --delete-branch --squash + done + shell: bash + env: + GH_TOKEN: ${{ secrets.ROLLPYTORCH_TOKEN1 }} diff --git a/.github/workflows/oneshotSnapshotPackage.yml b/.github/workflows/oneshotSnapshotPackage.yml index bec2e21282f0..9c54c8b3f0ef 100644 --- a/.github/workflows/oneshotSnapshotPackage.yml +++ b/.github/workflows/oneshotSnapshotPackage.yml @@ -1,3 +1,4 @@ +# yamllint disable rule:line-length name: Release oneshot snapshot package on: diff --git a/.github/workflows/releaseSnapshotPackage.yml b/.github/workflows/releaseSnapshotPackage.yml index 3648152de9a1..7899110fa0e3 100644 --- a/.github/workflows/releaseSnapshotPackage.yml +++ b/.github/workflows/releaseSnapshotPackage.yml @@ -1,9 +1,9 @@ +# yamllint disable rule:line-length name: Release snapshot package on: - schedule: - - cron: '17 4 * * *' - + # schedule: + # - cron: '0 11 * * *' workflow_dispatch: jobs: diff --git a/.gitignore b/.gitignore index 5c407428929c..00a5bc96f221 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ externals/pytorch/ libtorch* /build/ +/setup_build/ __pycache__ *.pyc diff --git a/CMakeLists.txt b/CMakeLists.txt index ccbe7ccb3a98..44f02ac6af38 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,10 +25,13 @@ project(torch-mlir LANGUAGES CXX C) set(CMAKE_C_STANDARD 11) set(CMAKE_CXX_STANDARD 17) +include(CMakeDependentOption) + #------------------------------------------------------------------------------- # Project options #------------------------------------------------------------------------------- +option(TORCH_MLIR_ENABLE_WERROR_FLAG "Enable `-Werror` flag on supported directories, treat error as warning" OFF) option(TORCH_MLIR_USE_INSTALLED_PYTORCH "If depending on PyTorch use it as installed in the current Python environment" ON) option(TORCH_MLIR_ENABLE_REFBACKEND "Enable reference backend" ON) @@ -43,24 +46,21 @@ endif() option(TORCH_MLIR_OUT_OF_TREE_BUILD "Specifies an out of tree build" OFF) -# PT1 options. -option(TORCH_MLIR_ENABLE_PROJECT_PT1 "Enables the PyTorch1 project under projects/pt1" OFF) -# TODO: Rename/scope these. They use historic names for now to ease migration -# burden. -option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON) -option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF) -option(TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS "Build Torch dialect MLIR Python bindings but neither JIT IR Importer nor LTC backend" OFF) -if(TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS) - set(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OFF) - set(TORCH_MLIR_ENABLE_LTC OFF) -endif() -# Force enable the PT1 project if either the JIT_IR_IMPORTER or LTC is enabled. -if(NOT TORCH_MLIR_ENABLE_PROJECT_PT1) - if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OR TORCH_MLIR_ENABLE_LTC) - message(STATUS "Enabling projects/pt1 because features requiring it are enabled") - set(TORCH_MLIR_ENABLE_PROJECT_PT1 ON) +# PyTorch native extension gate. If OFF, then no features which depend on +# native extensions will be built. +option(TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS "Enables PyTorch native extension features" ON) +cmake_dependent_option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS OFF) +cmake_dependent_option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS OFF) + +option(TORCH_MLIR_ENABLE_ONNX_C_IMPORTER "Enables the ONNX C importer" OFF) + +macro(torch_mlir_enable_werror) + if(TORCH_MLIR_ENABLE_WERROR_FLAG) + if(NOT MSVC) + add_compile_options(-Werror) + endif() endif() -endif() +endmacro() #------------------------------------------------------------------------------- # Configure out-of-tree vs in-tree build @@ -235,4 +235,16 @@ endif() # Sub-projects #------------------------------------------------------------------------------- +# Sub-projects can bundle additional PyTorch extensions by adding them to this +# source target. It is typically empty unless if features are enabled. +if(MLIR_ENABLE_BINDINGS_PYTHON) + declare_mlir_python_sources(TorchMLIRPythonTorchExtensionsSources) +endif() + +# Build projects first as it may populate additional Python deps. add_subdirectory(projects) + +# Finish with top-level Python bindings so it can handle additional deps. +if(MLIR_ENABLE_BINDINGS_PYTHON) + add_subdirectory(python) +endif() \ No newline at end of file diff --git a/README.md b/README.md index cc479d8d35eb..1b0fff13bdb3 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ We have few paths to lower down to the Torch MLIR Dialect. - LazyTensorCore Read more details [here](docs/ltc_backend.md). - We also have basic TorchDynamo/PyTorch 2.0 support, see our - [long-term roadmap](docs/long_term_roadmap.md) and + [long-term roadmap](docs/roadmap.md) and [Thoughts on PyTorch 2.0](https://discourse.llvm.org/t/thoughts-on-pytorch-2-0/67000/3) for more details. @@ -52,7 +52,7 @@ Meeting links can be found [here](https://discourse.llvm.org/t/new-community-mee ## Install torch-mlir snapshot -At the time of writing, we release pre-built snapshot of torch-mlir for Python 3.11 on Linux and macOS. +At the time of writing, we release pre-built snapshots of torch-mlir for Python 3.11. If you have Python 3.11, the following commands initialize a virtual environment. ```shell diff --git a/build_tools/ci/build_posix.sh b/build_tools/ci/build_posix.sh new file mode 100755 index 000000000000..fec5e252e8d7 --- /dev/null +++ b/build_tools/ci/build_posix.sh @@ -0,0 +1,65 @@ +#!/bin/bash + +set -eu -o errtrace + +this_dir="$(cd $(dirname $0) && pwd)" +repo_root="$(cd $this_dir/../.. && pwd)" +build_dir="$repo_root/build" +install_dir="$repo_root/install" +mkdir -p "$build_dir" +build_dir="$(cd $build_dir && pwd)" +cache_dir="${cache_dir:-}" + +# Setup cache dir. +if [ -z "${cache_dir}" ]; then + cache_dir="${repo_root}/.build-cache" + mkdir -p "${cache_dir}" + cache_dir="$(cd ${cache_dir} && pwd)" +fi +echo "Caching to ${cache_dir}" +mkdir -p "${cache_dir}/ccache" +mkdir -p "${cache_dir}/pip" + +python="$(which python)" +echo "Using python: $python" + +export CMAKE_TOOLCHAIN_FILE="$this_dir/linux_default_toolchain.cmake" +export CC=clang +export CXX=clang++ +export CCACHE_DIR="${cache_dir}/ccache" +export CCACHE_MAXSIZE="350M" +export CMAKE_C_COMPILER_LAUNCHER=ccache +export CMAKE_CXX_COMPILER_LAUNCHER=ccache + +# Clear ccache stats. +ccache -z + +cd $repo_root + +echo "::group::CMake configure" +cmake -S "$repo_root/externals/llvm-project/llvm" -B "$build_dir" \ + -GNinja \ + -DCMAKE_BUILD_TYPE=Release \ + -DPython3_EXECUTABLE="$(which python)" \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DTORCH_MLIR_ENABLE_WERROR_FLAG=ON \ + -DCMAKE_INSTALL_PREFIX="$install_dir" \ + -DCMAKE_INSTALL_LIBDIR=lib \ + -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_EXTERNAL_PROJECTS="torch-mlir" \ + -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$repo_root" \ + -DLLVM_TARGETS_TO_BUILD=host \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DTORCH_MLIR_ENABLE_LTC=ON +echo "::endgroup::" + +echo "::group::Build" +cmake --build "$build_dir" --target tools/torch-mlir/all -- -k 0 +echo "::endgroup::" + +echo "::group::Unit tests" +cmake --build $repo_root/build --target check-torch-mlir-all +echo "::endgroup::" + +# Show ccache stats. +ccache --show-stats diff --git a/build_tools/ci/check_generated_sources.sh b/build_tools/ci/check_generated_sources.sh new file mode 100755 index 000000000000..719e221d71ba --- /dev/null +++ b/build_tools/ci/check_generated_sources.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +set -eu -o errtrace + +this_dir="$(cd $(dirname $0) && pwd)" +repo_root="$(cd $this_dir/../.. && pwd)" + +function _check_file_not_changed_by() { + # _check_file_not_changed_by + cmd="$1" + file="$2" + file_backup="$PWD/$(basename $file)" + file_new="$PWD/$(basename $file).new" + # Save the original file. + cp "$file" "$file_backup" + # Run the command to regenerate it. + "$1" || return 1 + # Save the new generated file. + cp "$file" "$file_new" + # Restore the original file. We want this function to not change the user's + # working tree state. + mv "$file_backup" "$file" + # We use git-diff as "just a diff program" (no SCM stuff) because it has + # nicer output than regular `diff`. + if ! git diff --no-index --quiet "$file" "$file_new"; then + echo "#######################################################" + echo "Generated file '${file}' is not up to date (see diff below)" + echo ">>> Please run '${cmd}' to update it <<<" + echo "#######################################################" + git diff --no-index --color=always "$file" "$file_new" + # TODO: Is there a better cleanup strategy that doesn't require duplicating + # this inside and outside the `if`? + rm "$file_new" + return 1 + fi + rm "$file_new" +} + +echo "::group:: Check that update_abstract_interp_lib.sh has been run" +_check_file_not_changed_by $repo_root/build_tools/update_abstract_interp_lib.sh $repo_root/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +echo "::endgroup::" + +echo "::group:: Check that update_torch_ods.sh has been run" +_check_file_not_changed_by $repo_root/build_tools/update_torch_ods.sh $repo_root/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +echo "::endgroup::" diff --git a/build_tools/ci/install_python_deps.sh b/build_tools/ci/install_python_deps.sh new file mode 100755 index 000000000000..6b49689ce8ea --- /dev/null +++ b/build_tools/ci/install_python_deps.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +set -eu -o errtrace + +this_dir="$(cd $(dirname $0) && pwd)" +repo_root="$(cd $this_dir/../.. && pwd)" +torch_version="${1:-unknown}" + +echo "::group::installing llvm python deps" +python -m pip install --no-cache-dir -r $repo_root/externals/llvm-project/mlir/python/requirements.txt +echo "::endgroup::" + +case $torch_version in + nightly) + echo "::group::installing nightly torch" + python3 -m pip install --no-cache-dir -r $repo_root/requirements.txt + python3 -m pip install --no-cache-dir -r $repo_root/torchvision-requirements.txt + echo "::endgroup::" + ;; + stable) + echo "::group::installing stable torch" + python3 -m pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu + python3 -m pip install --no-cache-dir -r $repo_root/build-requirements.txt + echo "::endgroup::" + ;; + *) + echo "Unrecognized torch version '$torch_version' (specify 'nightly' or 'stable' with cl arg)" + exit 1 + ;; +esac + +echo "::group::installing test requirements" +python -m pip install --no-cache-dir -r $repo_root/test-requirements.txt +echo "::endgroup::" diff --git a/build_tools/ci/linux_default_toolchain.cmake b/build_tools/ci/linux_default_toolchain.cmake new file mode 100644 index 000000000000..4e0c36c71be7 --- /dev/null +++ b/build_tools/ci/linux_default_toolchain.cmake @@ -0,0 +1,14 @@ +message(STATUS "Enabling thin archives (static libraries will not be relocatable)") +set(CMAKE_C_ARCHIVE_APPEND " qT ") +set(CMAKE_CXX_ARCHIVE_APPEND " qT ") +set(CMAKE_C_ARCHIVE_CREATE " crT ") +set(CMAKE_CXX_ARCHIVE_CREATE " crT ") + +set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=lld -Wl,--gdb-index") +set(CMAKE_MODULE_LINKER_FLAGS_INIT "-fuse-ld=lld -Wl,--gdb-index") +set(CMAKE_SHARED_LINKER_FLAGS_INIT "-fuse-ld=lld -Wl,--gdb-index") + +set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -gsplit-dwarf -ggnu-pubnames") +set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -gsplit-dwarf -ggnu-pubnames") +set(CMAKE_C_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELWITHDEBINFO} -gsplit-dwarf -ggnu-pubnames") +set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -gsplit-dwarf -ggnu-pubnames") diff --git a/build_tools/ci/test_posix.sh b/build_tools/ci/test_posix.sh new file mode 100755 index 000000000000..71a22d0f714e --- /dev/null +++ b/build_tools/ci/test_posix.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +set -eu -o errtrace + +this_dir="$(cd $(dirname $0) && pwd)" +repo_root="$(cd $this_dir/../.. && pwd)" +torch_version="${1:-unknown}" + +export PYTHONPATH="$repo_root/build/tools/torch-mlir/python_packages/torch_mlir:$repo_root/projects/pt1" + +echo "::group::Run Linalg e2e integration tests" +python -m e2e_testing.main --config=linalg -v +echo "::endgroup::" + +echo "::group::Run make_fx + TOSA e2e integration tests" +python -m e2e_testing.main --config=make_fx_tosa -v +echo "::endgroup::" + +echo "::group::Run TOSA e2e integration tests" +python -m e2e_testing.main --config=tosa -v +echo "::endgroup::" + +echo "::group::Run Stablehlo e2e integration tests" +python -m e2e_testing.main --config=stablehlo -v +echo "::endgroup::" + +echo "::group::Run ONNX e2e integration tests" +python -m e2e_testing.main --config=onnx -v +echo "::endgroup::" + +case $torch_version in + nightly) + # Failing with: NotImplementedError: + # Could not run 'aten::empty.memory_format' with arguments from the 'Lazy' backend. + # As of 2024-01-07 + # echo "::group::Run Lazy Tensor Core e2e integration tests" + # python -m e2e_testing.main --config=lazy_tensor_core -v + # echo "::endgroup::" + + # TODO: There is one failing test in this group on stable. It could + # be xfailed vs excluding entirely. + echo "::group::Run TorchDynamo e2e integration tests" + python -m e2e_testing.main --config=torchdynamo -v + echo "::endgroup::" + ;; + stable) + ;; + *) + echo "Unrecognized torch version '$torch_version' (specify 'nightly' or 'stable' with cl arg)" + exit 1 + ;; +esac diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index 3df3dfb4f453..17512de87a45 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -305,6 +305,9 @@ function test_in_tree() { echo ":::: Run Linalg e2e integration tests" python -m e2e_testing.main --config=linalg -v + echo ":::: Run Onnx e2e integration tests" + python -m e2e_testing.main --config=onnx -v + # Dynamo is changing a lot in nightly versions, and thus the implementation # tends to become incompatible to the stable version. echo ":::: Run TorchDynamo e2e integration tests" @@ -351,7 +354,6 @@ function setup_venv() { echo ":::: Using stable dependencies" python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/stable-requirements.txt python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt - python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/test-requirements.txt ;; *) echo "Unrecognized torch version '$torch_version'" @@ -359,6 +361,7 @@ function setup_venv() { ;; esac + python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/test-requirements.txt } function build_out_of_tree() { diff --git a/build_tools/write_env_file.sh b/build_tools/write_env_file.sh index 8f3c9a59357f..05179c56a07c 100755 --- a/build_tools/write_env_file.sh +++ b/build_tools/write_env_file.sh @@ -13,7 +13,7 @@ portable_realpath() { td="$(portable_realpath "$(dirname "$0")"/..)" build_dir="$(portable_realpath "${TORCH_MLIR_BUILD_DIR:-$td/build}")" -python_packages_dir="$build_dir/python_packages" +python_packages_dir="$build_dir/tools/torch-mlir/python_packages" write_env_file() { echo "Updating $build_dir/.env file" diff --git a/docs/add_ops.md b/docs/add_ops.md new file mode 100644 index 000000000000..1805f1700b47 --- /dev/null +++ b/docs/add_ops.md @@ -0,0 +1,164 @@ +# How to Add Ops to Torch-Mlir + +Collected links and contacts for how to add ops to torch-mlir. + + +
+Turbine Camp: Start Here +This document was previously known as `turbine-camp.md` to Nod.ai. "Turbine Camp" is part of Nod.ai's onboarding process. Welcome to turbine camp. This document originated at Nod.ai as a part of onboardding process, where new nod-ai folks learn about the architecture of our work by adding support for 2 ops to torch-mlir. I decided to put this into torch mlir because a lot of this is about torch-mlir. + +Written & maintained by @renxida + +Guides by other folks that were used during the creation of this document: +- [Chi Liu](https://gist.github.com/AmosLewis/dd31ab37517977b1c499d06495b4adc2) +- [Sunsoon](https://docs.google.com/document/d/1H79DwW_wnVzUU81EogwY5ueXgnl-QzKet1p2lnqPar4/edit?pli=1) + +## Before you begin... + +Nod-ai maintains the pipeline below, which allows us to take a ML model from e.g. huggingface, and compile it to a variety of devices including llvm-cpu, rocm and cuda and more as an optimized `vmfb` binary. + +1. The pipeline begins with a huggingface model, or some other supported source like llama.cpp. +2. [nod-ai/SHARK-Turbine](https://github.com/nod-ai/SHARK-Turbine) takes a huggingface model and exports a `.mlir` file. +3. **[llvm/torch-mlir](https://github.com/llvm/torch-mlir)**, which you will be working on in turbine-camp, will lower torchscript, torch dialect, and torch aten ops further into a mixture `linalg` or `math` MLIR dialects (with occasionally other dialects in the mix) +4. [IREE](https://github.com/openxla/iree) converts the final `.mlir` file into a binary (typically `.vmfb`) for running on a device (llvm-cpu, rocm, vulcan, cuda, etc). + +The details of how we do it and helpful commands to help you set up each repo is in [Sungsoon's Shark Getting Started Google Doc](https://docs.google.com/document/d/1H79DwW_wnVzUU81EogwY5ueXgnl-QzKet1p2lnqPar4/edit?pli=1) + +PS: IREE is pronounced Eerie, and hence the ghost icon. + +## How to begin +1. You will start by adding support for 2 ops in torch-mlir, to get you familiar with the center of our pipeline. Begin by reading [torch-mlir's documentation on how to implement a new torch op](https://github.com/llvm/torch-mlir/blob/main/docs/Torch-ops-E2E-implementation.md), and set up `llvm/torch_mlir` using https://github.com/llvm/torch-mlir/blob/main/docs/development.md +2. Pick 1 of the yet-unimplemented from the following. You should choose something that looks easy to you. **Make sure you create an issue by clicking the little "target" icon to the right of the op, thereby marking the op as yours** + - [TorchToLinalg ops tracking issue](https://github.com/nod-ai/SHARK-Turbine/issues/347) + - [TorchOnnnxToTorch ops tracking issue](https://github.com/nod-ai/SHARK-Turbine/issues/215) +3. Implement it. For torch -> linalg, see the how to torchop section below. For Onnx ops, see how to onnx below. +5. Make a pull request and reference your issue. When the pull request is closed, also close your issue to mark the op as done + +
+ +### How to TorchToLinalg + +You will need to do 4 things: +- make sure the op exists in `torch_ods_gen.py`, and then run `build_tools/update_torch_ods.sh`, and then build. This generates `GeneratedTorchOps.td`, which is used to generate the cpp and h files where ops function signatures are defined. + - Reference [torch op registry](https://github.com/pytorch/pytorch/blob/7451dd058564b5398af79bfc1e2669d75f9ecfa2/torch/csrc/jit/passes/utils/op_registry.cpp#L21) +- make sure the op exists in `abstract_interp_lib_gen.py`, and then run `build_tools/update_abstract_interp_lib.sh`, and then build. This generates `AbstractInterpLib.cpp`, which is used to generate the cpp and h files where ops function signatures are defined. + - Reference [torch shape functions](https://github.com/pytorch/pytorch/blob/7451dd058564b5398af79bfc1e2669d75f9ecfa2/torch/jit/_shape_functions.py#L1311) +- write test cases. They live in `projects/pt1`. See the [Dec 2023 example](https://github.com/llvm/torch-mlir/pull/2640/files). +- implement the op in one of the `lib/Conversion/TorchToLinalg/*.cpp` files + +Reference Examples +- [A Dec 2023 example with the most up to date lowering](https://github.com/llvm/torch-mlir/pull/2640/files) +- [Chi's simple example of adding op lowering](https://github.com/llvm/torch-mlir/pull/1454) useful instructions and referring links for you to understand the op lowering pipeline in torch-mlir in the comments + +Resources: +- how to set up torch-mlir: [https://github.com/llvm/torch-mlir/blob/main/docs/development.md](https://github.com/llvm/torch-mlir/blob/main/docs/development.md#checkout-and-build-from-source) +- torch-mlir doc on how to debug and test: [ttps://github.com/llvm/torch-mlir/blob/main/docs/development.md#testing](https://github.com/llvm/torch-mlir/blob/main/docs/development.md#testing) +- [torch op registry](https://github.com/pytorch/pytorch/blob/7451dd058564b5398af79bfc1e2669d75f9ecfa2/torch/csrc/jit/passes/utils/op_registry.cpp#L21) +- [torch shape functions](https://github.com/pytorch/pytorch/blob/7451dd058564b5398af79bfc1e2669d75f9ecfa2/torch/jit/_shape_functions.py#L1311) + +### How to TorchOnnxToTorch +0. Generate the big folder of ONNX IR. Use https://github.com/llvm/torch-mlir/blob/main/test/python/onnx_importer/import_smoke_test.py . Alternatively, if you're trying to support a certain model, convert that model to onnx IR with + ``` + optimum-cli export onnx --model facebook/opt-125M fb-opt + python -m torch_mlir.tools.import_onnx fb-opt/model.onnx -o fb-opt-125m.onnx.mlir + ``` +2. Find an instance of the Op that you're trying to implement inside the smoke tests folder or the generated model IR, and write a test case. Later you will save it to one of the files in `torch-mlir/test/Conversion/TorchOnnxToTorch`, but for now feel free to put it anywhere. +3. Implement the op in `lib/Conversion/TorchOnnxToTorch/something.cpp`. +4. Test the conversion by running `./build/bin/torch-mlir-opt -split-input-file -verify-diagnostics -convert-torch-onnx-to-torch your_mlir_file.mlir`. For more details, see https://github.com/llvm/torch-mlir/blob/main/docs/development.md#testing . Xida usually creates a separate MLIR file to test it to his satisfaction before integrating it into one of the files at `torch-mlir/test/Conversion/TorchOnnxToTorch`. + +Helpful examples: +- [A Dec 2023 example where an ONNX op is implemented](https://github.com/llvm/torch-mlir/pull/2641/files#diff-b584b152020af6d2e5dbf62a08b2f25ed5afc2c299228383b9651d22d44b5af4R493) +- [Vivek's example of ONNX op lowering](https://github.com/llvm/torch-mlir/commit/dc9ea08db5ac295b4b3f91fc776fef6a702900b9) + +## List of Tools you may need to use (this will be incorporated into the above instructions later) + +- Generate FILECHECK tests from MLIR test cases: `torch-mlir-opt -convert- /tmp/your_awesome_testcase.mlir | externals/llvm-project/mlir/utils/generate-test-checks.py +`. Please don't just paste the generated tests - reference them to write your own + +## Contacts +People who've worked on this for a while +- Vivek (@vivek97 on discord) +- Chi.Liu@amd.com + +Recent Turbine Camp Attendees, from recent to less recent +- Xida.ren@amd.com (@xida_ren on discord) +- Sungsoon.Cho@amd.com + +## Links + +- Tutorials + - [Sungsoon's Shark Getting Started Google Doc](https://docs.google.com/document/d/1H79DwW_wnVzUU81EogwY5ueXgnl-QzKet1p2lnqPar4/edit?pli=1) + - This document contains commands that would help you set up shark and run demos + - [How to implement ONNX op lowering](https://github.com/llvm/torch-mlir/blob/main/docs/importers/onnx_importer.md) +- Examples + - [A Dec 2023 example with the most up to date lowering](https://github.com/llvm/torch-mlir/pull/2640/files) + - Chi's Example Lowering + - Github issue and code detailing how to implement the lowring of an OP. + - [Chi's simple example of adding op lowering](https://github.com/llvm/torch-mlir/pull/1454) useful instructions and referring links for you to understand the op lowering pipeline in torch-mlir in the comments + - If you have questions, reach out to [Chi on Discord](https://discordapp.com/channels/973663919757492264/1104195883307892837/1180233875058868224) + - [Vivek's example of ONNX op lowering](https://github.com/llvm/torch-mlir/commit/dc9ea08db5ac295b4b3f91fc776fef6a702900b9) +- Find Ops To Lower + - [Torch MLIR + ONNX Unimplemented Ops on Sharepoint](https://amdcloud-my.sharepoint.com/:x:/r/personal/esaimana_amd_com/Documents/Torch%20MLIR%20+%20ONNX%20Unimplemented%20Ops.xlsx?d=w438f26fac8fd44eeafb89bc99e2c563b&csf=1&web=1&e=Qd4eHm) + - If you don't have access yet, request it. + - nod-ai/SHARK-Turbine ssues tracking op support + - [Model and Op Support](https://github.com/nod-ai/SHARK-Turbine/issues/119) + - [ONNX op support](https://github.com/nod-ai/SHARK-Turbine/issues/215) + + +## Chi's useful commands for debugging torch mlir + +https://gist.github.com/AmosLewis/dd31ab37517977b1c499d06495b4adc2 + +## How to write test cases and test your new op + +https://github.com/llvm/torch-mlir/blob/main/docs/development.md#testing + + + +## How to set up vs code and intellisence for [torch-mlir] +Xida: This is optional. If you're using VS code like me, you might want to set it up so you can use the jump to definition / references, auto fix, and other features. + +Feel free to contact me on discord if you have trouble figuring this out. + +You may need to write something like this into your + +```.vscode/settings.json``` + +under `torch-mlir` + +```json +{ + "files.associations": { + "*.inc": "cpp", + "ranges": "cpp", + "regex": "cpp", + "functional": "cpp", + "chrono": "cpp", + "__functional_03": "cpp", + "target": "cpp" + }, + "cmake.sourceDirectory": ["/home/xida/torch-mlir/externals/llvm-project/llvm"], + "cmake.buildDirectory": "${workspaceFolder}/build", + "cmake.generator": "Ninja", + "cmake.configureArgs": [ + "-DLLVM_ENABLE_PROJECTS=mlir", + "-DLLVM_EXTERNAL_PROJECTS=\"torch-mlir\"", + "-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR=\"/home/xida/torch-mlir\"", + "-DCMAKE_BUILD_TYPE=Release", + "-DCMAKE_C_COMPILER_LAUNCHER=ccache", + "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache", + "-DLLVM_ENABLE_PROJECTS=mlir", + "-DLLVM_EXTERNAL_PROJECTS=torch-mlir", + "-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR=${workspaceFolder}", + "-DMLIR_ENABLE_BINDINGS_PYTHON=ON", + "-DLLVM_ENABLE_ASSERTIONS=ON", + "-DLLVM_TARGETS_TO_BUILD=host", + ], + "C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools", + "cmake.configureEnvironment": { + "PATH": "/home/xida/miniconda/envs/torch-mlir/bin:/home/xida/miniconda/condabin:/home/xida/miniconda/bin:/home/xida/miniconda/bin:/home/xida/miniconda/condabin:/home/xida/miniconda/bin:/home/xida/miniconda/bin:/home/xida/miniconda/condabin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin" + }, + "cmake.cmakePath": "/home/xida/miniconda/envs/torch-mlir/bin/cmake", // make sure this is a cmake that knows where your python is +} +``` +The important things to note are the `cmake.configureArgs`, which specify the location of your torch mlir, and the `cmake.sourceDirectory`, which indicates that CMAKE should not build from the current directory and should instead build from `externals/llvm-project/llvm` diff --git a/docs/adding_abstract_interpretation_functions.md b/docs/adding_abstract_interpretation_functions.md index b5e427e1adfd..eeebb9c315fa 100644 --- a/docs/adding_abstract_interpretation_functions.md +++ b/docs/adding_abstract_interpretation_functions.md @@ -4,7 +4,7 @@ As part of adding support for a Torch operator in Torch-MLIR, it is usually necessary to define a shape and dtype function so that the compiler can infer -the shapes and dtypes of result tensors for the operator. We use the +the shapes and dtypes of result tensors for the operator. We use the [abstract interpretation library](abstract_interp_lib.md) for this process. ## Step-by-step guide @@ -19,7 +19,7 @@ We will use the example of adding support for the `torch.aten.tanh` op. file is the "rosetta stone" that allows translating between e.g. `torch.aten.tanh`, `AtenTanhOp`, and the shape and dtype function signatures are: - + - `def aten〇tanh〡shape(self: List[int]) -> List[int]:` - `def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:` @@ -39,10 +39,10 @@ We will use the example of adding support for the `torch.aten.tanh` op. But in general, you will need to write the function and test it (see the comments about "Shape, dtype, and decomposition function testing infrastructure" in `testing_framework.py`). New shape - functions should be added upstream following the example of [this PR](https://github.com/pytorch/pytorch/pull/76889), - though it can be useful to iterate locally in `abstract_interp_lib_gen.py` + functions should be added upstream following the example of [this PR](https://github.com/pytorch/pytorch/pull/76889), + though it can be useful to iterate locally in `abstract_interp_lib_gen.py` first. - + Similarly, dtype functions should ideally just be a call to the helper `promote_dtypes` defined in `library_generator.py`. However, some ops will require some extra logic to calculate the right result types. While dtypes diff --git a/docs/adding_an_e2e_test.md b/docs/adding_an_e2e_test.md index 7b74b904a0f8..91eee0520f56 100644 --- a/docs/adding_an_e2e_test.md +++ b/docs/adding_an_e2e_test.md @@ -87,7 +87,7 @@ following order: 1. Shape of input tensor. Use `-1` for dynamic dimensions 2. Dtype of the input tensor -3. Boolean representing whether the input tensor [has value semantics](https://github.com/llvm/torch-mlir/blob/ba17a4d6c09b4bbb4ef21b1d8d4a93cb056be109/python/torch_mlir/jit_ir_importer/csrc/class_annotator.h#L54-L67). This +3. Boolean representing whether the input tensor [has value semantics](https://github.com/llvm/torch-mlir/blob/main/projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.h#L54-L67). This will always be true for E2E tests, since the [Torch-MLIR backend contract](architecture.md#the-backend-contract) requires all tensors in the IR to eventually have value semantics. diff --git a/docs/architecture.md b/docs/architecture.md index 8ee6bfda8a0a..e2ef378bd99c 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -184,7 +184,7 @@ semantics. And often users want to erase the shapes in the trace to allow dynamic shapes for the trace. Additionally, the Python-level data structures and APIs are very parallel between `torch.jit.script` and `torch.jit.trace`, so we consider both of those as the same from the perspective of the responsibilities -of the compiler. Both are accessed via the `torch_mlir.compile` Python API. +of the compiler. Both are accessed via the `torch_mlir.torchscript.compile` Python API. ### Modeling the `torch.nn.Module` object (`IValue`) hierarchy for TorchScript @@ -442,5 +442,5 @@ characteristics. ### Presentations and Talks -* 2021-10-07: MLIR ODM: Introduction to Torch-MLIR. ([recording](https://www.youtube.com/watch?v=QbNkex-gizs) and [slides](https://docs.google.com/presentation/d/1ZhzfE4EK6XV7AdQTYicrsE_OYjkER_yiB0vBeszRfzY/edit#slide=id.gf56404f79c_1_55)) -* 2022-08-20: Overview of Torch-MLIR passes. ([recording](https://www.youtube.com/watch?v=ZpwlVxsD9_U) and [slides](https://drive.google.com/file/d/1ZSlk1HGttRuVhJSxtP6spWt_hxClit2T/view)) +* 2021-10-07: MLIR ODM: Introduction to Torch-MLIR. ([recording](https://www.youtube.com/watch?v=QbNkex-gizs) and [slides](https://docs.google.com/presentation/d/1ZhzfE4EK6XV7AdQTYicrsE_OYjkER_yiB0vBeszRfzY/edit#slide=id.gf56404f79c_1_55)) +* 2022-08-20: Overview of Torch-MLIR passes. ([recording](https://www.youtube.com/watch?v=ZpwlVxsD9_U) and [slides](https://drive.google.com/file/d/1ZSlk1HGttRuVhJSxtP6spWt_hxClit2T/view)) diff --git a/docs/development.md b/docs/development.md index c60312e7ac5e..27ff8b7bfaad 100644 --- a/docs/development.md +++ b/docs/development.md @@ -5,9 +5,12 @@ ```shell git clone https://github.com/llvm/torch-mlir cd torch-mlir -git submodule update --init +git submodule update --init --progress ``` +Optionally, use `--depth=1` to make a shallow clone of the submodules. +While this is running, you can already setup the Python venv and dependencies in the next step. + ## Setup your Python VirtualEnvironment and Dependencies Also, ensure that you have the appropriate `python-dev` package installed @@ -42,12 +45,12 @@ cmake -GNinja -Bbuild \ -DLLVM_TARGETS_TO_BUILD=host \ externals/llvm-project/llvm ``` -The following additional quality of life flags can be used to reduce build time: +#### Flags that can reduce build time: * Enabling clang on Linux ```shell -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ ``` -* Enabling ccache: +* Enabling ccache ```shell -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ``` @@ -69,6 +72,14 @@ By default we download the latest version of libtorch. We have an experimental p -DLIBTORCH_VARIANT=shared # Set the variant of libtorch to build / link against. (`shared`|`static` and optionally `cxxabi11`) ``` +#### Flags to enable MLIR debugging: + +* Enabling `--debug` and `--debug-only` flags (see [MLIR docs](https://mlir.llvm.org/getting_started/Debugging/)) for the `torch-mlir-opt` tool +```shell + -DCMAKE_BUILD_TYPE=RelWithDebInfo \ # or =Debug + -DLLVM_ENABLE_ASSERTIONS=ON \ +``` + ### Building against a pre-built LLVM If you have built llvm-project separately in the directory `$LLVM_INSTALL_DIR`, you can also build the project *out-of-tree* using the following command as template: @@ -109,37 +120,50 @@ cmake --build build ### Linux and macOS ```shell -export PYTHONPATH=`pwd`/build/python_packages/torch_mlir:`pwd`/projects/pt1/examples +export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/test/python/fx_importer ``` ### Windows PowerShell ```shell -$env:PYTHONPATH = "$PWD/build/python_packages/torch_mlir;$PWD/projects/pt1/examples" +$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/test/python/fx_importer" ``` ## Testing MLIR output in various dialects -To test the compiler's output to the different MLIR dialects, you can use the example `projects/pt1/examples/torchscript_resnet18_all_output_types.py`. +To test the MLIR output to torch dialect, you can use `test/python/fx_importer/basic_test.py`. Make sure you have activated the virtualenv and set the `PYTHONPATH` above (if running on Windows, modify the environment variable as shown above): ```shell source mlir_venv/bin/activate -export PYTHONPATH=`pwd`/build/tpython_packages/torch_mlir:`pwd`/projects/pt1/examples +export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/test/python/fx_importer +python test/python/fx_importer/basic_test.py +``` + +This will display the basic example in TORCH dialect. + +To test the compiler's output to the different MLIR dialects, you can also use the deprecated path +using torchscript with the example `projects/pt1/examples/torchscript_resnet18_all_output_types.py`. +This path doesn't give access to the current generation work that is being driven via the fx_importer +and may lead to errors. + +Same as above, but with different python path and example: +```shell +export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/projects/pt1/examples python projects/pt1/examples/torchscript_resnet18_all_output_types.py ``` This will display the Resnet18 network example in three dialects: TORCH, LINALG on TENSORS and TOSA. -The main functionality is on `torch_mlir.compile()`'s `output_type`. +The main functionality is on `torch_mlir.torchscript.compile()`'s `output_type`. Ex: ```python -module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch") +module = torch_mlir.torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch") ``` -Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`. +`output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`. ## Jupyter diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/README.md b/docs/importers/onnx_importer.md similarity index 78% rename from include/torch-mlir/Conversion/TorchOnnxToTorch/README.md rename to docs/importers/onnx_importer.md index 6de1cc923411..a0b861d6d9cb 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/README.md +++ b/docs/importers/onnx_importer.md @@ -3,19 +3,16 @@ We enable the direct representation of many ONNX features directly in the `torch` dialect as `torch.operator` custom ops with names like `onnx.{OperatorName}`. The majority of ONNX operators are represented -with a systematic transformation. See -[onnx_importer.py](https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/importers/onnx_importer.py) -for the reference importer which complies with the rules below -(this is planned to be upstreamed to torch-mlir proper in the near -future). +with a systematic transformation. `torch_mlir.extras.onnx_importer` +for the reference importer which complies with the rules below. ## Adding new ONNX operators With the exception of certain special or complicated ONNX operators, most are relatively straight-forward to map, following this general procedure: -* Plan the ops you wish to support by consulting the - [ONNX operator database](https://onnx.ai/onnx/operators/). +* Plan the ops you wish to support by consulting the + [ONNX operator database](https://onnx.ai/onnx/operators/). * This database has detailed diffs wrt different support versions but at the level of detail we operate, most version diffs are inconsequential and just require a bit more pattern support. @@ -26,23 +23,36 @@ are relatively straight-forward to map, following this general procedure: * Open the corresponding implementation file `DefaultDomainXtoY.cpp` corresponding with the alphabetic sort of the op and add a conversion. * Generate successful test cases: - * Either run the Turbine importer to produce MLIR output for all - ops/models in the ONNX test suite or use a dump that someone has - generated: - * [2023-Nov-21](https://drive.google.com/file/d/1P6QaRXGnCeApjdjNmykLxWa-yqMmIO-d/view?usp=sharing) + * All `onnx_importer.py` tests are dumped to the test temp dir (success + or failure). This is typically located under + `tools/torch-mlir/test/python/onnx_importer/Output`. The `.mlir` files + under there should provide good variants to drive lit test coverage of + conversion. + * (Optionally) If there is an Onnx file that uses the op of interest, + convert that file to Onnx MLIR form using the following Python command, + `python -m torch_mlir.tools.import_onnx my_model.onnx`. * There are often many variants of tests for checking conformance of different historic ONNX encodings, but these are often not load bearing at the MLIR level. - * Pick a handful of test cases and add them to - `test/Conversion/TorchOnnxToTorch/simple_ops_x_to_y.mlir` corresponding to an - alphabetic breakdown. At this time, ignore tests that are not exercising + * Pick a handful of test cases and add them to + `test/Conversion/TorchOnnxToTorch/simple_ops_x_to_y.mlir` corresponding to + an alphabetic breakdown. At this time, ignore tests that are not exercising useful differences in the pattern implementations. + * (Optionally) Use `torch-mlir-opt` to validate the outputs of the new op. + First, build the project using + `cmake --build build --target tools/torch-mlir/all`. This will generate + the conversion binary, `torch-mlir-opt`. Then call `torch-mlir-opt` with + the MLIR pass `convert-torch-onnx-to-torch`: + ``` + build/bin/torch-mlir-opt -convert-torch-onnx-to-torch \ + -split-input-file [DESIRED_ONNX_FILE].mlir + ``` * Generate failure test cases: * Some ops have forms that do not (easily) map to torch-mlir. If you leave an op under-implemented, add a failing test case to `test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir`. -* Optional but recommended: Use your test case files to fuzz against the - torch-mlir backend of your choice by running a backend conversion pipeline +* Optional but recommended: Use your test case files to fuzz against the + torch-mlir backend of your choice by running a backend conversion pipeline and fixing any crashes/issues. * Send a patch with your changes. @@ -105,7 +115,7 @@ not yet implemented. The `IsolatedFromAbove` parent of the ops can contain the following metadata: -* `torch.onnx_meta.ir_version`: 64bit `IntegerAttr` corresponding to +* `torch.onnx_meta.ir_version`: 64bit `IntegerAttr` corresponding to `ModelProto.ir_version`. * `torch.onnx_meta.producer_name`: `StringAttr` corresponding to `ModelProto.producer_name`. @@ -125,7 +135,7 @@ are only minor variations of an op. Major variations should use ### Special op forms -Certain ONNX operators map to different structural components of +Certain ONNX operators map to different structural components of torch-mlir's representation: * `ConstantOfShape`: Mapped to `torch.vtensor.literal` with diff --git a/docs/ltc_backend.md b/docs/ltc_backend.md index b0177542899b..d047bbf9d812 100644 --- a/docs/ltc_backend.md +++ b/docs/ltc_backend.md @@ -103,7 +103,7 @@ At some point, the tensors will be synced in order to execute the computation -- >>> torch._lazy.mark_step() ``` -This triggers a call to `LazyGraphExecutor::SyncLiveTensorsGraph` somewhere in the guts of LTC, which collects all the `TorchMlirNode`s (technically `torch::lazy::Node`s at this point) from the current trace and +This triggers a call to `LazyGraphExecutor::SyncLiveTensorsGraph` somewhere in the guts of LTC, which collects all the `TorchMlirNode`s (technically `torch::lazy::Node`s at this point) from the current trace and creates an instance of `TorchMlirLoweringContext`. Here, the `TorchMlirNode`s are lowered to JIT via `mlir_node_lowering.cpp` and inserted into a `jit::Graph`. Next, `TorchMlirLoweringContext::Build` is executed and the final `jit::Graph` is sent to `torch_mlir::importJitFunctionAsFuncOp` to generate MLIR using the existing infrastructure from Torch-MLIR. @@ -121,7 +121,7 @@ Finally, the compiled computation is sent to `TorchMlirBackendImpl::ExecuteCompu ## Implementing a custom backend -A reference implementation of a custom backend is available [here](../python/torch_mlir/csrc/reference_lazy_backend/). +A reference implementation of a custom backend is available [here](../python/torch_mlir/csrc/reference_lazy_backend/). All the work involved with generating MLIR is handled in the base LTC backend, so vendors only need to worry about implementing `Compile`, `ExecuteComputation`, and some other minor methods to interface with the device. A pybind is needed to invoke C++ code to register the autogen PyTorch kernels and the custom backend itself. diff --git a/docs/ltc_examples.md b/docs/ltc_examples.md index b9306edce492..217761a51ebd 100644 --- a/docs/ltc_examples.md +++ b/docs/ltc_examples.md @@ -33,18 +33,18 @@ Received 1 arguments, and returned 2 results during ExecuteCompile! Results: tensor([[0.7616, 0.9640, 0.9951, 0.9993, 0.9999]], device='lazy:0') -JIT Graph: +JIT Graph: graph(%p0 : Float(1, 5)): %1 : Float(1, 5) = aten::tanh(%p0) return (%p0, %1) -MLIR: +MLIR: func.func @graph(%arg0: !torch.vtensor<[1,5],f32>) -> (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,5],f32>) { %0 = torch.aten.tanh %arg0 : !torch.vtensor<[1,5],f32> -> !torch.vtensor<[1,5],f32> return %arg0, %0 : !torch.vtensor<[1,5],f32>, !torch.vtensor<[1,5],f32> } -Input/Output Alias Mapping: +Input/Output Alias Mapping: Output: 0 -> Input param: 0 In Mark Step: true diff --git a/docs/long_term_roadmap.md b/docs/roadmap.md similarity index 94% rename from docs/long_term_roadmap.md rename to docs/roadmap.md index 0f0940efc32d..f60502a52423 100644 --- a/docs/long_term_roadmap.md +++ b/docs/roadmap.md @@ -51,6 +51,22 @@ the ecosystem are: Most of this document describes long-term ecosystem changes that will address these, drastically improving Torch-MLIR's ability to meet its goals. +## Current API Paths + +Currently, there are two main API paths for the torch-mlir project: + +- The first path is part of the legacy project pt1 code + (torch_mlir.torchscript.compile). This allows users to test the compiler's + output to the different MLIR dialects (`TORCH`, `TOSA`, `LINALG_ON_TENSORS`, + `RAW` and `STABLEHLO`). This path is deprecated and doesn’t give access to + the current generation work that is being driven via the fx_importer. It is + tied to the old Torchscript path. +- The second path (torch_mlir.fx.export_and_import) allows users to import a + consolidated torch.export.ExportedProgram instance of an arbitrary Python + callable (an nn.Module, a function or a method) and output to torch dialect + mlir module. This path is aligned with PyTorch's roadmap, but the path is + not fully functional yet. + ## Roadmap ### Refactoring the frontend diff --git a/externals/llvm-project b/externals/llvm-project index b1e618b941d7..a8d87d17943a 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit b1e618b941d715115e69b152e8eebc649f4544c3 +Subproject commit a8d87d17943a1c5e76bd1878db99670dc7594453 diff --git a/externals/stablehlo b/externals/stablehlo index ab709fe48de8..4ac26f8786d4 160000 --- a/externals/stablehlo +++ b/externals/stablehlo @@ -1 +1 @@ -Subproject commit ab709fe48de88c67717abfbd7ef17425eb95ddaf +Subproject commit 4ac26f8786d491c5d8376e6e563d1b72af09de75 diff --git a/include/torch-mlir-c/Dialects.h b/include/torch-mlir-c/Dialects.h index 99156c17009c..60f6ec1e5e26 100644 --- a/include/torch-mlir-c/Dialects.h +++ b/include/torch-mlir-c/Dialects.h @@ -22,4 +22,4 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Torch, torch); } #endif -#endif // TORCHMLIR_C_DIALECTS_H +#endif // TORCHMLIR_C_DIALECTS_H diff --git a/include/torch-mlir-c/Registration.h b/include/torch-mlir-c/Registration.h index 4d582e61f132..7d607693d56b 100644 --- a/include/torch-mlir-c/Registration.h +++ b/include/torch-mlir-c/Registration.h @@ -23,7 +23,7 @@ extern "C" { MLIR_CAPI_EXPORTED void torchMlirRegisterAllDialects(MlirContext context); /** Registers all passes for symbolic access with the global registry. */ -MLIR_CAPI_EXPORTED void torchMlirRegisterAllPasses(); +MLIR_CAPI_EXPORTED void torchMlirRegisterAllPasses(void); #ifdef __cplusplus } diff --git a/include/torch-mlir-c/TorchTypes.h b/include/torch-mlir-c/TorchTypes.h index c852dd61387d..b214e147d5d9 100644 --- a/include/torch-mlir-c/TorchTypes.h +++ b/include/torch-mlir-c/TorchTypes.h @@ -35,7 +35,7 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchNnModuleTypeGet(MlirContext context, MlirStringRef className); /// Gets the !torch.nn.Module typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNnModuleTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNnModuleTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.optional type. @@ -53,7 +53,7 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchOptionalTypeGetContained(MlirType containedType); /// Gets the !torch.optional typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchOptionalTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchOptionalTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.tuple type. @@ -75,7 +75,7 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos); /// Gets the !torch.tuple typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchTupleTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchTupleTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.union type. @@ -97,7 +97,7 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos); /// Gets the !torch.union typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchUnionTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchUnionTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.list type. @@ -113,7 +113,7 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGet(MlirType containedType); MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGetContainedType(MlirType t); /// Gets the !torch.list typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchListTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchListTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.Device type. @@ -126,7 +126,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDevice(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchDeviceTypeGet(MlirContext context); /// Gets the !torch.device typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDeviceTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDeviceTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.Generator type. @@ -139,7 +139,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchGenerator(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchGeneratorTypeGet(MlirContext context); /// Gets the !torch.generator typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchGeneratorTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchGeneratorTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.bool type. @@ -152,7 +152,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchBool(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchBoolTypeGet(MlirContext context); /// Gets the !torch.bool typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchBoolTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchBoolTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.int type. @@ -165,7 +165,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchInt(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchIntTypeGet(MlirContext context); /// Gets the !torch.int typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchIntTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchIntTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.float type. @@ -178,7 +178,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchFloat(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchFloatTypeGet(MlirContext context); /// Gets the !torch.float typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchFloatTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchFloatTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.LinearParams type. @@ -192,7 +192,7 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context); /// Gets the !torch.linearparams typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.qint8 type. @@ -205,7 +205,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt8(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt8TypeGet(MlirContext context); /// Gets the !torch.qint8 typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQInt8TypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQInt8TypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.quint8 type. @@ -218,7 +218,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQUInt8(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context); /// Gets the !torch.quint8 typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQUInt8TypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQUInt8TypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.tensor type. @@ -266,7 +266,7 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t); /// Gets the !torch.tensor typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.vtensor type. @@ -312,7 +312,7 @@ torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes); MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t); /// Gets the !torch.vtensor typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchValueTensorTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchValueTensorTypeGetTypeID(void); //===----------------------------------------------------------------------===// // !torch.none type. @@ -325,7 +325,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNone(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchNoneTypeGet(MlirContext context); /// Gets the !torch.none typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNoneTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNoneTypeGetTypeID(void); //===----------------------------------------------------------------------===// // !torch.str type. @@ -338,7 +338,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchString(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchStringTypeGet(MlirContext context); /// Gets the !torch.str typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchStringTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchStringTypeGetTypeID(void); //===----------------------------------------------------------------------===// // !torch.any type. @@ -351,7 +351,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchAny(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchAnyTypeGet(MlirContext context); /// Gets the !torch.any typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchAnyTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchAnyTypeGetTypeID(void); //===----------------------------------------------------------------------===// // !torch.number type. @@ -364,7 +364,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNumber(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchNumberTypeGet(MlirContext context); /// Gets the !torch.number typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNumberTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNumberTypeGetTypeID(void); //===----------------------------------------------------------------------===// // !torch.dict type. @@ -387,7 +387,7 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetKeyType(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetValueType(MlirType t); /// Gets the !torch.dict typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDictTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDictTypeGetTypeID(void); #ifdef __cplusplus } diff --git a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h index f16b436c8790..159bcea7899e 100644 --- a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h +++ b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h @@ -10,9 +10,9 @@ #ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_ #define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_ -#include "mlir/IR/IRMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Support/LLVM.h" diff --git a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td index ac2c114ded74..12a74faa44d3 100644 --- a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td +++ b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td @@ -137,6 +137,7 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter", let arguments = (ins Variadic:$inputs, Variadic:$outputs, + DenseI64ArrayAttr:$dimension_map, DefaultValuedAttr:$unique_indices ); let results = (outs Variadic:$results); @@ -313,9 +314,6 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention", int64_t getOutputRank() { return getOutputType().getRank(); } - int64_t getIterationDomainRank() { - return 2; - }; // Method to implement for specifying output range for // DestinationStyleOpInterface std::pair getDpsInitsPositionRange() { diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td index 3a130f472b3b..ed58c699559c 100644 --- a/include/torch-mlir/Conversion/Passes.td +++ b/include/torch-mlir/Conversion/Passes.td @@ -105,6 +105,15 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> { let constructor = "mlir::torch::createConvertTorchToLinalgPass()"; } +def ConvertTorchToTensor : Pass<"convert-torch-to-tensor", "func::FuncOp"> { + let summary = "Convert Torch ops to the Tensor dialect"; + let description = [{ + Converts any `Torch` operators that were expressible as `Tensor` dialect + operations. + }]; + let constructor = "mlir::torch::createConvertTorchToTensorPass()"; +} + def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> { let summary = "Convert Torch ops to TOSA ops"; let description = [{ diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index 5b144503c0ec..261b4df3bd09 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -33,6 +33,8 @@ struct OpBinder { Location getLoc() { return op->getLoc(); } + int getNumOperands() { return op->getNumOperands(); } + // Operand matches of different arities. ParseResult tensorOperand(Value &value0) { if (op->getNumOperands() != 1) @@ -54,6 +56,20 @@ struct OpBinder { return success(); } + ParseResult tensorOperands(SmallVector &valueList, + int64_t numOperands) { + if (op->getNumOperands() != numOperands) + return failure(); + for (int64_t i = 0; i < numOperands; i++) { + Value curr = op->getOperand(i); + if (!toValidTensorType(curr.getType())) { + return failure(); + } + valueList.push_back(curr); + } + return success(); + } + ParseResult tensorOperandAtIndex(Value &valueIdx, int64_t idx) { if (idx >= op->getNumOperands()) return failure(); @@ -63,6 +79,13 @@ struct OpBinder { return success(); } + ParseResult tensorOperandsList(llvm::SmallVectorImpl &values) { + for (uint32_t i = 0; i < op->getNumOperands(); i++) { + values.push_back(op->getOperand(i)); + } + return success(); + } + // Result type matchers of different arities. ParseResult tensorResultType(Torch::ValueTensorType &type0) { if (op->getNumResults() != 1) @@ -74,6 +97,17 @@ struct OpBinder { return success(); } + ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx, + int64_t idx) { + if (idx >= op->getNumResults()) + return failure(); + auto t = toValidTensorType(op->getResult(idx).getType()); + if (!t) + return failure(); + typeIdx = t; + return success(); + } + // Attribute accessors. ParseResult s64BoolAttr(bool &value, StringRef nameSuffix, bool defaultValue = false) { @@ -113,8 +147,65 @@ struct OpBinder { return failure(); } + ParseResult f32FloatAttr(float &value, StringRef nameSuffix, + float defaultValue = 0.0f) { + SmallString<64> name("torch.onnx."); + name.append(nameSuffix); + auto attr = op->getAttr(name); + if (!attr) { + value = defaultValue; + return success(); + } + if (auto floatAttr = dyn_cast(attr)) { + FloatType t = cast(floatAttr.getType()); + if (t.getWidth() != 32) + return failure(); + value = floatAttr.getValue().convertToFloat(); + return success(); + } + return failure(); + } + + ParseResult s64IntegerArrayAttr(llvm::SmallVector &values, + StringRef nameSuffix, + ArrayRef defaults) { + SmallString<64> name("torch.onnx."); + name.append(nameSuffix); + auto attr = op->getAttr(name); + if (!attr) { + values.append(defaults.begin(), defaults.end()); + return success(); + } + if (auto arrayAttr = dyn_cast(attr)) { + for (auto element : arrayAttr) { + auto integerAttr = element.dyn_cast(); + if (!integerAttr) + return failure(); + IntegerType t = cast(integerAttr.getType()); + if (!t.isSigned() || t.getWidth() != 64) + return failure(); + values.push_back(integerAttr.getSInt()); + } + return success(); + } + return failure(); + } + + ParseResult denseElementsAttr(ElementsAttr elementsattr, + StringRef nameSuffix) { + SmallString<64> name("torch.onnx."); + name.append(nameSuffix); + Attribute attr = op->getAttr(name); + if (!attr || !isa(attr)) { + return failure(); + } + + elementsattr = cast(attr); + return success(); + } + ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix, - std::string defaultValue = "") { + std::string defaultValue = "") { SmallString<64> name("torch.onnx."); name.append(nameSuffix); auto attr = op->getAttr(name); @@ -160,7 +251,10 @@ class OnnxCustomOpConversionPattern OnnxCustomOpConversionPattern(MLIRContext *context, std::string domainPrefix, int64_t domainVersion) : OpConversionPattern(context), domainPrefix(std::move(domainPrefix)), - domainVersion(domainVersion) {} + domainVersion(domainVersion) { + // Onnx lowerings could produce other Onnx operations during the rewrite. + setHasBoundedRewriteRecursion(); + } LogicalResult matchAndRewrite(Torch::OperatorOp op, OpAdaptor adaptor, diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h new file mode 100644 index 000000000000..afc14a95ef13 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h @@ -0,0 +1,25 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H +#define TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" + +namespace mlir::torch::onnx_c { + +Value createConstantIntList(OpBinder binder, + ConversionPatternRewriter &rewriter, + SmallVector cstInput); + +Type getQTorchTypeFromTorchIntType(Type ty); + +} // namespace mlir::torch::onnx_c + +#endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H diff --git a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h index 134fbeca46dc..5d2095f04f14 100644 --- a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h +++ b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h @@ -36,7 +36,7 @@ Value getZeroPaddedTensor(Operation *op, OpBuilder &b, Value &input, // padding value is zero. Value getDynamicZeroPaddedTensor(Operation *op, OpBuilder &b, Value &input, SmallVectorImpl &padding, - int unpaddedDims = 0); + int unpaddedDims = 0, Value pad = {}); // Helper function to caculate the output tensor dims for convolution-like ops. // Along each dim: @@ -95,6 +95,8 @@ FailureOr getBackendTypeForScalarType(MLIRContext *context, torch_upstream::ScalarType dtypeInt); +bool isUnsignedTorchType(Type type); + } // namespace torch_to_linalg } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Conversion/TorchToTensor/TorchToTensor.h b/include/torch-mlir/Conversion/TorchToTensor/TorchToTensor.h new file mode 100644 index 000000000000..9dd5a65429ed --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToTensor/TorchToTensor.h @@ -0,0 +1,23 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHTOTENSOR_TORCHTOTENSOR_H +#define TORCHMLIR_CONVERSION_TORCHTOTENSOR_TORCHTOTENSOR_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace torch { +std::unique_ptr> createConvertTorchToTensorPass(); +} // namespace torch +} // namespace mlir + +#endif // TORCHMLIR_CONVERSION_TORCHTOTENSOR_TORCHTOTENSOR_H diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h index c1b355e3c50d..44b9bbdde3b2 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h @@ -37,33 +37,31 @@ TosaOpT createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op, return CreateOpAndInfer(rewriter, op->getLoc(), outType, lhs, rhs); } -// This specialization is for Div op. Unlike other binary ops, it doesn't support -// floating type. +// This specialization is for Div op. Unlike other binary ops, it doesn't +// support floating type. template <> tosa::DivOp createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op, TensorType outType, Value lhs, Value rhs); std::optional convertTorchIndexToTfIndices(PatternRewriter &rewriter, - Operation *op, - Value params_value, - Value index_value, - int32_t axis); + Operation *op, + Value params_value, + Value index_value, + int32_t axis); // Lowers torch.aten.Gather operators to a sequence of TOSA ops. // Revised from // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc -std::optional convertGatherNdOp(PatternRewriter &rewriter, - Operation *op, Type out_type, - Value params_value, - Value indices_value); +std::optional convertGatherNdOp(PatternRewriter &rewriter, Operation *op, + Type out_type, Value params_value, + Value indices_value); std::optional convertScatterNdOp(PatternRewriter &rewriter, Operation *op, Type outType, Value paramsValue, Value indicesValue, Value fillValues); - // Lowers ReduceAll to a sequence of TOSA ops. std::optional convertReduceAllOp(PatternRewriter &rewriter, Operation *op, @@ -106,6 +104,12 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, bool keep_dims); +// Lowers LinalgVectorNorm to a sequence of TOSA ops. +std::optional +convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op, + RankedTensorType output_type, Value input_value, + ElementsAttr axes_elems, bool keep_dims); + } // namespace tosa } // namespace mlir diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h index 5e6934001d7c..876b81092ae9 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h @@ -71,7 +71,7 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType); // op. This allows shape inference during the framework to TOSA lowering. template TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty, - Args &&... args) { + Args &&...args) { auto op = rewriter.create(loc, result_ty, args...); InferShapedTypeOpInterface shapeInterface = @@ -115,7 +115,7 @@ TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty, template void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op, - Type result_ty, Args &&... args) { + Type result_ty, Args &&...args) { auto result = CreateOpAndInfer(rewriter, op->getLoc(), result_ty, args...); rewriter.replaceOp(op, result->getResults()); @@ -130,7 +130,7 @@ TypedValue transposeBy(Location loc, // Get accumulator type for AvgPool2dOp. LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input, - TypeAttr &accType); + TypeAttr &accType); } // namespace tosa } // namespace mlir diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index 516954b88fbc..b76efe869a0f 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -88,7 +88,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef shape, // should be converted builtin types. Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, std::optional srcOriginalDtype = std::nullopt, - std::optional dstOriginalDtype = std::nullopt); + std::optional dstOriginalDtype = std::nullopt, + std::optional originalScalar = std::nullopt); Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, Value torchOptionalInt, Value builtinInt, diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c000411e8e44..5b985a80b301 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -19,51 +19,6 @@ //===----------------------------------------------------------------------===// -def Torch_AtenTanhOp : Torch_Op<"aten.tanh", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::tanh : (Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenTanhOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenTanhOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - -def Torch_AtenTanh_Op : Torch_Op<"aten.tanh_", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::tanh_ : (Tensor) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self - ); - let results = (outs - Torch_NonValueTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenTanh_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenTanh_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - def Torch_AtenHardtanhOp : Torch_Op<"aten.hardtanh", [ AllowsTypeRefinement, HasValueSemantics, @@ -346,6 +301,51 @@ def Torch_AtenLog_Op : Torch_Op<"aten.log_", [ }]; } +def Torch_AtenSeluOp : Torch_Op<"aten.selu", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::selu : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSeluOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenSeluOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenSelu_Op : Torch_Op<"aten.selu_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::selu_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSelu_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenSelu_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenSigmoidOp : Torch_Op<"aten.sigmoid", [ AllowsTypeRefinement, HasValueSemantics, @@ -436,6 +436,51 @@ def Torch_AtenSign_Op : Torch_Op<"aten.sign_", [ }]; } +def Torch_AtenSinhOp : Torch_Op<"aten.sinh", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::sinh : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSinhOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenSinhOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenSinh_Op : Torch_Op<"aten.sinh_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::sinh_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSinh_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenSinh_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenSgnOp : Torch_Op<"aten.sgn", [ AllowsTypeRefinement, HasValueSemantics, @@ -751,12 +796,12 @@ def Torch_AtenSin_Op : Torch_Op<"aten.sin_", [ }]; } -def Torch_AtenExpOp : Torch_Op<"aten.exp", [ +def Torch_AtenAsinOp : Torch_Op<"aten.asin", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::exp : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::asin : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -765,20 +810,20 @@ def Torch_AtenExpOp : Torch_Op<"aten.exp", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenExpOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenAsinOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenExpOp::print(OpAsmPrinter &printer) { + void AtenAsinOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenExp_Op : Torch_Op<"aten.exp_", [ +def Torch_AtenAsin_Op : Torch_Op<"aten.asin_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::exp_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::asin_ : (Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self ); @@ -787,21 +832,21 @@ def Torch_AtenExp_Op : Torch_Op<"aten.exp_", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenExp_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenAsin_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenExp_Op::print(OpAsmPrinter &printer) { + void AtenAsin_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenExpm1Op : Torch_Op<"aten.expm1", [ +def Torch_AtenAsinhOp : Torch_Op<"aten.asinh", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::expm1 : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::asinh : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -810,20 +855,20 @@ def Torch_AtenExpm1Op : Torch_Op<"aten.expm1", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenExpm1Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenAsinhOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenExpm1Op::print(OpAsmPrinter &printer) { + void AtenAsinhOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenExpm1_Op : Torch_Op<"aten.expm1_", [ +def Torch_AtenAsinh_Op : Torch_Op<"aten.asinh_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::expm1_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::asinh_ : (Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self ); @@ -832,21 +877,21 @@ def Torch_AtenExpm1_Op : Torch_Op<"aten.expm1_", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenExpm1_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenAsinh_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenExpm1_Op::print(OpAsmPrinter &printer) { + void AtenAsinh_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenCosOp : Torch_Op<"aten.cos", [ +def Torch_AtenExpOp : Torch_Op<"aten.exp", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::cos : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::exp : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -855,20 +900,20 @@ def Torch_AtenCosOp : Torch_Op<"aten.cos", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenCosOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenExpOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenCosOp::print(OpAsmPrinter &printer) { + void AtenExpOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenCos_Op : Torch_Op<"aten.cos_", [ +def Torch_AtenExp_Op : Torch_Op<"aten.exp_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::cos_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::exp_ : (Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self ); @@ -877,21 +922,21 @@ def Torch_AtenCos_Op : Torch_Op<"aten.cos_", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenCos_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenExp_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenCos_Op::print(OpAsmPrinter &printer) { + void AtenExp_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenAcosOp : Torch_Op<"aten.acos", [ +def Torch_AtenExpm1Op : Torch_Op<"aten.expm1", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::acos : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::expm1 : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -900,20 +945,20 @@ def Torch_AtenAcosOp : Torch_Op<"aten.acos", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAcosOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenExpm1Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenAcosOp::print(OpAsmPrinter &printer) { + void AtenExpm1Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenAcos_Op : Torch_Op<"aten.acos_", [ +def Torch_AtenExpm1_Op : Torch_Op<"aten.expm1_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::acos_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::expm1_ : (Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self ); @@ -922,21 +967,21 @@ def Torch_AtenAcos_Op : Torch_Op<"aten.acos_", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAcos_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenExpm1_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenAcos_Op::print(OpAsmPrinter &printer) { + void AtenExpm1_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenAtanOp : Torch_Op<"aten.atan", [ +def Torch_AtenCosOp : Torch_Op<"aten.cos", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::atan : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::cos : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -945,20 +990,20 @@ def Torch_AtenAtanOp : Torch_Op<"aten.atan", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAtanOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenCosOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenAtanOp::print(OpAsmPrinter &printer) { + void AtenCosOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenAtan_Op : Torch_Op<"aten.atan_", [ +def Torch_AtenCos_Op : Torch_Op<"aten.cos_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::atan_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::cos_ : (Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self ); @@ -967,68 +1012,66 @@ def Torch_AtenAtan_Op : Torch_Op<"aten.atan_", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAtan_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenCos_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenAtan_Op::print(OpAsmPrinter &printer) { + void AtenCos_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenAtan2Op : Torch_Op<"aten.atan2", [ +def Torch_AtenCoshOp : Torch_Op<"aten.cosh", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::atan2 : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::cosh : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAtan2Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenCoshOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenAtan2Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenCoshOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenAtan2_Op : Torch_Op<"aten.atan2_", [ +def Torch_AtenCosh_Op : Torch_Op<"aten.cosh_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::atan2_ : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::cosh_ : (Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$other + Torch_NonValueTensorType:$self ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAtan2_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenCosh_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenAtan2_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenCosh_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenAsinOp : Torch_Op<"aten.asin", [ +def Torch_AtenAcosOp : Torch_Op<"aten.acos", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::asin : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::acos : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -1037,20 +1080,20 @@ def Torch_AtenAsinOp : Torch_Op<"aten.asin", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAsinOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenAcosOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenAsinOp::print(OpAsmPrinter &printer) { + void AtenAcosOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenAsin_Op : Torch_Op<"aten.asin_", [ +def Torch_AtenAcos_Op : Torch_Op<"aten.acos_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::asin_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::acos_ : (Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self ); @@ -1059,21 +1102,21 @@ def Torch_AtenAsin_Op : Torch_Op<"aten.asin_", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAsin_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenAcos_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenAsin_Op::print(OpAsmPrinter &printer) { + void AtenAcos_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenNegOp : Torch_Op<"aten.neg", [ +def Torch_AtenAcoshOp : Torch_Op<"aten.acosh", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::neg : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::acosh : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -1082,20 +1125,20 @@ def Torch_AtenNegOp : Torch_Op<"aten.neg", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNegOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenAcoshOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenNegOp::print(OpAsmPrinter &printer) { + void AtenAcoshOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenNeg_Op : Torch_Op<"aten.neg_", [ +def Torch_AtenAcosh_Op : Torch_Op<"aten.acosh_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::neg_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::acosh_ : (Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self ); @@ -1104,21 +1147,21 @@ def Torch_AtenNeg_Op : Torch_Op<"aten.neg_", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNeg_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenAcosh_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenNeg_Op::print(OpAsmPrinter &printer) { + void AtenAcosh_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenCeilOp : Torch_Op<"aten.ceil", [ +def Torch_AtenTanOp : Torch_Op<"aten.tan", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::ceil : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::tan : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -1127,20 +1170,20 @@ def Torch_AtenCeilOp : Torch_Op<"aten.ceil", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenCeilOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenTanOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenCeilOp::print(OpAsmPrinter &printer) { + void AtenTanOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenCeil_Op : Torch_Op<"aten.ceil_", [ +def Torch_AtenTan_Op : Torch_Op<"aten.tan_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::ceil_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::tan_ : (Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self ); @@ -1149,21 +1192,21 @@ def Torch_AtenCeil_Op : Torch_Op<"aten.ceil_", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenCeil_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenTan_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenCeil_Op::print(OpAsmPrinter &printer) { + void AtenTan_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenBitwiseNotOp : Torch_Op<"aten.bitwise_not", [ +def Torch_AtenTanhOp : Torch_Op<"aten.tanh", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::bitwise_not : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::tanh : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -1172,20 +1215,247 @@ def Torch_AtenBitwiseNotOp : Torch_Op<"aten.bitwise_not", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBitwiseNotOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenTanhOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenBitwiseNotOp::print(OpAsmPrinter &printer) { + void AtenTanhOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenBitwiseNot_Op : Torch_Op<"aten.bitwise_not_", [ +def Torch_AtenTanh_Op : Torch_Op<"aten.tanh_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::bitwise_not_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::tanh_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTanh_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenTanh_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenAtanOp : Torch_Op<"aten.atan", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::atan : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAtanOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAtanOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenAtan_Op : Torch_Op<"aten.atan_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::atan_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAtan_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAtan_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenAtanhOp : Torch_Op<"aten.atanh", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::atanh : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAtanhOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAtanhOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenAtanh_Op : Torch_Op<"aten.atanh_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::atanh_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAtanh_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAtanh_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenAtan2Op : Torch_Op<"aten.atan2", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::atan2 : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAtan2Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenAtan2Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenAtan2_Op : Torch_Op<"aten.atan2_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::atan2_ : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$other + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAtan2_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenAtan2_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenNegOp : Torch_Op<"aten.neg", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::neg : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNegOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenNegOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenNeg_Op : Torch_Op<"aten.neg_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::neg_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNeg_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenNeg_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenBitwiseNotOp : Torch_Op<"aten.bitwise_not", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::bitwise_not : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBitwiseNotOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenBitwiseNotOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenBitwiseNot_Op : Torch_Op<"aten.bitwise_not_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::bitwise_not_ : (Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self ); @@ -1485,49 +1755,51 @@ def Torch_AtenLerp_TensorOp : Torch_Op<"aten.lerp_.Tensor", [ }]; } -def Torch_AtenEqTensorOp : Torch_Op<"aten.eq.Tensor", [ +def Torch_AtenLerpScalarOp : Torch_Op<"aten.lerp.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::lerp.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$other + AnyTorchTensorType:$end, + AnyTorchScalarType:$weight ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenEqTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenLerpScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenEqTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenLerpScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenEq_TensorOp : Torch_Op<"aten.eq_.Tensor", [ +def Torch_AtenLerp_ScalarOp : Torch_Op<"aten.lerp_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::eq_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::lerp_.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$other + Torch_NonValueTensorType:$end, + AnyTorchScalarType:$weight ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenEq_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenLerp_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenEq_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenLerp_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } @@ -1814,12 +2086,12 @@ def Torch_AtenDiv_ScalarOp : Torch_Op<"aten.div_.Scalar", [ }]; } -def Torch_AtenNeScalarOp : Torch_Op<"aten.ne.Scalar", [ +def Torch_AtenFmodScalarOp : Torch_Op<"aten.fmod.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchScalarType:$other @@ -1829,20 +2101,20 @@ def Torch_AtenNeScalarOp : Torch_Op<"aten.ne.Scalar", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNeScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenFmodScalarOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenNeScalarOp::print(OpAsmPrinter &printer) { + void AtenFmodScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenNe_ScalarOp : Torch_Op<"aten.ne_.Scalar", [ +def Torch_AtenFmod_ScalarOp : Torch_Op<"aten.fmod_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::ne_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::fmod_.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, AnyTorchScalarType:$other @@ -1852,638 +2124,628 @@ def Torch_AtenNe_ScalarOp : Torch_Op<"aten.ne_.Scalar", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNe_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenFmod_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenNe_ScalarOp::print(OpAsmPrinter &printer) { + void AtenFmod_ScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenEqScalarOp : Torch_Op<"aten.eq.Scalar", [ +def Torch_AtenMaskedFillScalarOp : Torch_Op<"aten.masked_fill.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other + AnyTorchTensorType:$mask, + AnyTorchScalarType:$value ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenEqScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenMaskedFillScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenEqScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenMaskedFillScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenEq_ScalarOp : Torch_Op<"aten.eq_.Scalar", [ +def Torch_AtenMaskedFill_ScalarOp : Torch_Op<"aten.masked_fill_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::eq_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::masked_fill_.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - AnyTorchScalarType:$other + Torch_NonValueTensorType:$mask, + AnyTorchScalarType:$value ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenEq_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenMaskedFill_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenEq_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenMaskedFill_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenGtScalarOp : Torch_Op<"aten.gt.Scalar", [ +def Torch_AtenClampOp : Torch_Op<"aten.clamp", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other + AnyTorchOptionalScalarType:$min, + AnyTorchOptionalScalarType:$max ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenGtScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenClampOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenGtScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenClampOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenGt_ScalarOp : Torch_Op<"aten.gt_.Scalar", [ +def Torch_AtenClamp_Op : Torch_Op<"aten.clamp_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::gt_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp_ : (Tensor, Scalar?, Scalar?) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - AnyTorchScalarType:$other + AnyTorchOptionalScalarType:$min, + AnyTorchOptionalScalarType:$max ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenGt_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenClamp_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenGt_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenClamp_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenGeScalarOp : Torch_Op<"aten.ge.Scalar", [ +def Torch_AtenClampTensorOp : Torch_Op<"aten.clamp.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other + AnyTorchOptionalTensorType:$min, + AnyTorchOptionalTensorType:$max ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenGeScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenClampTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenGeScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenClampTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenGe_ScalarOp : Torch_Op<"aten.ge_.Scalar", [ +def Torch_AtenClamp_TensorOp : Torch_Op<"aten.clamp_.Tensor", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::ge_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp_.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - AnyTorchScalarType:$other + AnyTorchOptionalNonValueTensorType:$min, + AnyTorchOptionalNonValueTensorType:$max ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenGe_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenClamp_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenGe_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenClamp_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenLtScalarOp : Torch_Op<"aten.lt.Scalar", [ +def Torch_AtenClampMinOp : Torch_Op<"aten.clamp_min", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp_min : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other + AnyTorchScalarType:$min ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLtScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenClampMinOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenLtScalarOp::print(OpAsmPrinter &printer) { + void AtenClampMinOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenLt_ScalarOp : Torch_Op<"aten.lt_.Scalar", [ +def Torch_AtenClampMin_Op : Torch_Op<"aten.clamp_min_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::lt_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp_min_ : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - AnyTorchScalarType:$other + AnyTorchScalarType:$min ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLt_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenClampMin_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenLt_ScalarOp::print(OpAsmPrinter &printer) { + void AtenClampMin_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenLeScalarOp : Torch_Op<"aten.le.Scalar", [ +def Torch_AtenClampMinTensorOp : Torch_Op<"aten.clamp_min.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::le.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp_min.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other + AnyTorchTensorType:$min ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLeScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenClampMinTensorOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenLeScalarOp::print(OpAsmPrinter &printer) { + void AtenClampMinTensorOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenLe_ScalarOp : Torch_Op<"aten.le_.Scalar", [ +def Torch_AtenClampMin_TensorOp : Torch_Op<"aten.clamp_min_.Tensor", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::le_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp_min_.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - AnyTorchScalarType:$other + Torch_NonValueTensorType:$min ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLe_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenClampMin_TensorOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenLe_ScalarOp::print(OpAsmPrinter &printer) { + void AtenClampMin_TensorOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenFmodScalarOp : Torch_Op<"aten.fmod.Scalar", [ +def Torch_AtenClampMaxOp : Torch_Op<"aten.clamp_max", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp_max : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other + AnyTorchScalarType:$max ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFmodScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenClampMaxOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenFmodScalarOp::print(OpAsmPrinter &printer) { + void AtenClampMaxOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenFmod_ScalarOp : Torch_Op<"aten.fmod_.Scalar", [ +def Torch_AtenClampMax_Op : Torch_Op<"aten.clamp_max_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::fmod_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp_max_ : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - AnyTorchScalarType:$other + AnyTorchScalarType:$max ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFmod_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenClampMax_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenFmod_ScalarOp::print(OpAsmPrinter &printer) { + void AtenClampMax_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenMaskedFillScalarOp : Torch_Op<"aten.masked_fill.Scalar", [ +def Torch_AtenClampMaxTensorOp : Torch_Op<"aten.clamp_max.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp_max.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$mask, - AnyTorchScalarType:$value + AnyTorchTensorType:$max ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMaskedFillScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenClampMaxTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenMaskedFillScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenClampMaxTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenMaskedFill_ScalarOp : Torch_Op<"aten.masked_fill_.Scalar", [ +def Torch_AtenClampMax_TensorOp : Torch_Op<"aten.clamp_max_.Tensor", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::masked_fill_.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp_max_.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$mask, - AnyTorchScalarType:$value + Torch_NonValueTensorType:$max ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMaskedFill_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenClampMax_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenMaskedFill_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenClampMax_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenClampOp : Torch_Op<"aten.clamp", [ +def Torch_AtenLog2Op : Torch_Op<"aten.log2", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)`"; + let summary = "Generated op for `aten::log2 : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchOptionalScalarType:$min, - AnyTorchOptionalScalarType:$max + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClampOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenLog2Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenClampOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenLog2Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenClamp_Op : Torch_Op<"aten.clamp_", [ +def Torch_AtenLog2_Op : Torch_Op<"aten.log2_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::clamp_ : (Tensor, Scalar?, Scalar?) -> (Tensor)`"; + let summary = "Generated op for `aten::log2_ : (Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self, - AnyTorchOptionalScalarType:$min, - AnyTorchOptionalScalarType:$max + Torch_NonValueTensorType:$self ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClamp_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenLog2_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenClamp_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenLog2_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenClampTensorOp : Torch_Op<"aten.clamp.Tensor", [ +def Torch_AtenLog10Op : Torch_Op<"aten.log10", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)`"; + let summary = "Generated op for `aten::log10 : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchOptionalTensorType:$min, - AnyTorchOptionalTensorType:$max + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClampTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenLog10Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenClampTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenLog10Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenClamp_TensorOp : Torch_Op<"aten.clamp_.Tensor", [ +def Torch_AtenLog10_Op : Torch_Op<"aten.log10_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::clamp_.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)`"; + let summary = "Generated op for `aten::log10_ : (Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self, - AnyTorchOptionalNonValueTensorType:$min, - AnyTorchOptionalNonValueTensorType:$max + Torch_NonValueTensorType:$self ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClamp_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenLog10_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenClamp_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenLog10_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenClampMinOp : Torch_Op<"aten.clamp_min", [ +def Torch_AtenSqrtOp : Torch_Op<"aten.sqrt", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::clamp_min : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::sqrt : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$min + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClampMinOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenSqrtOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenClampMinOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenSqrtOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenClampMin_Op : Torch_Op<"aten.clamp_min_", [ +def Torch_AtenSqrt_Op : Torch_Op<"aten.sqrt_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::clamp_min_ : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::sqrt_ : (Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self, - AnyTorchScalarType:$min + Torch_NonValueTensorType:$self ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClampMin_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenSqrt_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenClampMin_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenSqrt_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenClampMinTensorOp : Torch_Op<"aten.clamp_min.Tensor", [ +def Torch_AtenLog1pOp : Torch_Op<"aten.log1p", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::clamp_min.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::log1p : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$min + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClampMinTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenLog1pOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenClampMinTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenLog1pOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenClampMin_TensorOp : Torch_Op<"aten.clamp_min_.Tensor", [ +def Torch_AtenLog1p_Op : Torch_Op<"aten.log1p_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::clamp_min_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::log1p_ : (Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$min + Torch_NonValueTensorType:$self ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClampMin_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenLog1p_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenClampMin_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenLog1p_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenClampMaxOp : Torch_Op<"aten.clamp_max", [ +def Torch_AtenLogitOp : Torch_Op<"aten.logit", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::clamp_max : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::logit : (Tensor, float?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$max + AnyTorchOptionalFloatType:$eps ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClampMaxOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenLogitOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenClampMaxOp::print(OpAsmPrinter &printer) { + void AtenLogitOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenClampMax_Op : Torch_Op<"aten.clamp_max_", [ +def Torch_AtenLogit_Op : Torch_Op<"aten.logit_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::clamp_max_ : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::logit_ : (Tensor, float?) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - AnyTorchScalarType:$max + AnyTorchOptionalFloatType:$eps ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClampMax_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenLogit_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenClampMax_Op::print(OpAsmPrinter &printer) { + void AtenLogit_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenClampMaxTensorOp : Torch_Op<"aten.clamp_max.Tensor", [ +def Torch_AtenRsqrtOp : Torch_Op<"aten.rsqrt", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::clamp_max.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::rsqrt : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$max + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClampMaxTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenRsqrtOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenClampMaxTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenRsqrtOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenClampMax_TensorOp : Torch_Op<"aten.clamp_max_.Tensor", [ +def Torch_AtenRsqrt_Op : Torch_Op<"aten.rsqrt_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::clamp_max_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::rsqrt_ : (Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$max + Torch_NonValueTensorType:$self ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClampMax_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenRsqrt_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenClampMax_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenRsqrt_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenLog2Op : Torch_Op<"aten.log2", [ +def Torch_AtenAbsOp : Torch_Op<"aten.abs", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::log2 : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::abs : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -2492,20 +2754,20 @@ def Torch_AtenLog2Op : Torch_Op<"aten.log2", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLog2Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenAbsOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenLog2Op::print(OpAsmPrinter &printer) { + void AtenAbsOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenLog2_Op : Torch_Op<"aten.log2_", [ +def Torch_AtenAbs_Op : Torch_Op<"aten.abs_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::log2_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::abs_ : (Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self ); @@ -2514,21 +2776,21 @@ def Torch_AtenLog2_Op : Torch_Op<"aten.log2_", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLog2_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenAbs_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenLog2_Op::print(OpAsmPrinter &printer) { + void AtenAbs_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenLog10Op : Torch_Op<"aten.log10", [ +def Torch_AtenReciprocalOp : Torch_Op<"aten.reciprocal", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::log10 : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::reciprocal : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -2537,20 +2799,20 @@ def Torch_AtenLog10Op : Torch_Op<"aten.log10", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLog10Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenReciprocalOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenLog10Op::print(OpAsmPrinter &printer) { + void AtenReciprocalOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenLog10_Op : Torch_Op<"aten.log10_", [ +def Torch_AtenReciprocal_Op : Torch_Op<"aten.reciprocal_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::log10_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::reciprocal_ : (Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self ); @@ -2559,448 +2821,223 @@ def Torch_AtenLog10_Op : Torch_Op<"aten.log10_", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLog10_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenReciprocal_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenLog10_Op::print(OpAsmPrinter &printer) { + void AtenReciprocal_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenSqrtOp : Torch_Op<"aten.sqrt", [ +def Torch_AtenBitwiseAndTensorOp : Torch_Op<"aten.bitwise_and.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::sqrt : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSqrtOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenBitwiseAndTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenSqrtOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenBitwiseAndTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenSqrt_Op : Torch_Op<"aten.sqrt_", [ +def Torch_AtenBitwiseAnd_TensorOp : Torch_Op<"aten.bitwise_and_.Tensor", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::sqrt_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_and_.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$other ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSqrt_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenBitwiseAnd_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenSqrt_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenBitwiseAnd_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenLog1pOp : Torch_Op<"aten.log1p", [ +def Torch_AtenBitwiseAndScalarOp : Torch_Op<"aten.bitwise_and.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::log1p : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchScalarType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLog1pOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenBitwiseAndScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenLog1pOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenBitwiseAndScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenLog1p_Op : Torch_Op<"aten.log1p_", [ +def Torch_AtenBitwiseAnd_ScalarOp : Torch_Op<"aten.bitwise_and_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::log1p_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_and_.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self + Torch_NonValueTensorType:$self, + AnyTorchScalarType:$other ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLog1p_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenBitwiseAnd_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenLog1p_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenBitwiseAnd_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenRsqrtOp : Torch_Op<"aten.rsqrt", [ +def Torch_AtenBitwiseOrTensorOp : Torch_Op<"aten.bitwise_or.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::rsqrt : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenRsqrtOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenBitwiseOrTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenRsqrtOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenBitwiseOrTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenRsqrt_Op : Torch_Op<"aten.rsqrt_", [ +def Torch_AtenBitwiseOr_TensorOp : Torch_Op<"aten.bitwise_or_.Tensor", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::rsqrt_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_or_.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$other ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenRsqrt_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenBitwiseOr_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenRsqrt_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenBitwiseOr_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenAbsOp : Torch_Op<"aten.abs", [ +def Torch_AtenBitwiseXorTensorOp : Torch_Op<"aten.bitwise_xor.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::abs : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_xor.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAbsOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenBitwiseXorTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenAbsOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenBitwiseXorTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenAbs_Op : Torch_Op<"aten.abs_", [ +def Torch_AtenBitwiseXor_TensorOp : Torch_Op<"aten.bitwise_xor_.Tensor", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::abs_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_xor_.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$other ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAbs_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenBitwiseXor_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenAbs_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenBitwiseXor_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenReciprocalOp : Torch_Op<"aten.reciprocal", [ +def Torch_AtenBitwiseLeftShiftTensorOp : Torch_Op<"aten.bitwise_left_shift.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::reciprocal : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_left_shift.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenReciprocalOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenBitwiseLeftShiftTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenReciprocalOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - -def Torch_AtenReciprocal_Op : Torch_Op<"aten.reciprocal_", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::reciprocal_ : (Tensor) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self - ); - let results = (outs - Torch_NonValueTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenReciprocal_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenReciprocal_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - -def Torch_AtenBitwiseAndTensorOp : Torch_Op<"aten.bitwise_and.Tensor", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenBitwiseAndTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenBitwiseAndTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenBitwiseAnd_TensorOp : Torch_Op<"aten.bitwise_and_.Tensor", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::bitwise_and_.Tensor : (Tensor, Tensor) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$other - ); - let results = (outs - Torch_NonValueTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenBitwiseAnd_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenBitwiseAnd_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenBitwiseAndScalarOp : Torch_Op<"aten.bitwise_and.Scalar", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$other - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenBitwiseAndScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenBitwiseAndScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenBitwiseAnd_ScalarOp : Torch_Op<"aten.bitwise_and_.Scalar", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::bitwise_and_.Scalar : (Tensor, Scalar) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self, - AnyTorchScalarType:$other - ); - let results = (outs - Torch_NonValueTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenBitwiseAnd_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenBitwiseAnd_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenBitwiseOrTensorOp : Torch_Op<"aten.bitwise_or.Tensor", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenBitwiseOrTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenBitwiseOrTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenBitwiseOr_TensorOp : Torch_Op<"aten.bitwise_or_.Tensor", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::bitwise_or_.Tensor : (Tensor, Tensor) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$other - ); - let results = (outs - Torch_NonValueTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenBitwiseOr_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenBitwiseOr_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenBitwiseXorTensorOp : Torch_Op<"aten.bitwise_xor.Tensor", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::bitwise_xor.Tensor : (Tensor, Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenBitwiseXorTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenBitwiseXorTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenBitwiseXor_TensorOp : Torch_Op<"aten.bitwise_xor_.Tensor", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::bitwise_xor_.Tensor : (Tensor, Tensor) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$other - ); - let results = (outs - Torch_NonValueTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenBitwiseXor_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenBitwiseXor_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenBitwiseLeftShiftTensorOp : Torch_Op<"aten.bitwise_left_shift.Tensor", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::bitwise_left_shift.Tensor : (Tensor, Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenBitwiseLeftShiftTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenBitwiseLeftShiftTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenBitwiseLeftShiftTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } @@ -3426,6 +3463,7 @@ def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -3475,6 +3513,7 @@ def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -3525,6 +3564,7 @@ def Torch_AtenSubTensorOp : Torch_Op<"aten.sub.Tensor", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -3652,60 +3692,488 @@ def Torch_AtenSub_ScalarOp : Torch_Op<"aten.sub_.Scalar", [ }]; } -def Torch_AtenMulScalarOp : Torch_Op<"aten.mul.Scalar", [ +def Torch_AtenMulScalarOp : Torch_Op<"aten.mul.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMulScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenMulScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasCanonicalizer = 1; +} + +def Torch_AtenMul_ScalarOp : Torch_Op<"aten.mul_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::mul_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMul_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenMul_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenEqTensorOp : Torch_Op<"aten.eq.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEqTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenEqTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + +def Torch_AtenEq_TensorOp : Torch_Op<"aten.eq_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::eq_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$other + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEq_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenEq_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenLeScalarOp : Torch_Op<"aten.le.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::le.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLeScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenLeScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + +def Torch_AtenLe_ScalarOp : Torch_Op<"aten.le_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::le_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLe_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenLe_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenLtScalarOp : Torch_Op<"aten.lt.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLtScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenLtScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + +def Torch_AtenLt_ScalarOp : Torch_Op<"aten.lt_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::lt_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLt_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenLt_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenGtScalarOp : Torch_Op<"aten.gt.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenGtScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenGtScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + +def Torch_AtenGt_ScalarOp : Torch_Op<"aten.gt_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::gt_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenGt_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenGt_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenGeScalarOp : Torch_Op<"aten.ge.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenGeScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenGeScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + +def Torch_AtenGe_ScalarOp : Torch_Op<"aten.ge_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::ge_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenGe_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenGe_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenEqScalarOp : Torch_Op<"aten.eq.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEqScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenEqScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + +def Torch_AtenEq_ScalarOp : Torch_Op<"aten.eq_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::eq_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEq_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenEq_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenNeScalarOp : Torch_Op<"aten.ne.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNeScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenNeScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + +def Torch_AtenNe_ScalarOp : Torch_Op<"aten.ne_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::ne_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNe_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenNe_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenFloorOp : Torch_Op<"aten.floor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::floor : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenFloorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenFloorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; + let hasFolder = 1; +} + +def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::floor_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenFloor_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenFloor_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenCeilOp : Torch_Op<"aten.ceil", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::ceil : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$other + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMulScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenCeilOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenMulScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenCeilOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; - let hasCanonicalizer = 1; + let hasFolder = 1; } -def Torch_AtenMul_ScalarOp : Torch_Op<"aten.mul_.Scalar", [ +def Torch_AtenCeil_Op : Torch_Op<"aten.ceil_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::mul_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::ceil_ : (Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self, - AnyTorchScalarType:$other + Torch_NonValueTensorType:$self ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMul_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenCeil_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenMul_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenCeil_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenFloorOp : Torch_Op<"aten.floor", [ +def Torch_AtenRoundOp : Torch_Op<"aten.round", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::floor : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::round : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -3714,21 +4182,21 @@ def Torch_AtenFloorOp : Torch_Op<"aten.floor", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFloorOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenRoundOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenFloorOp::print(OpAsmPrinter &printer) { + void AtenRoundOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; - let hasCanonicalizer = 1; + let hasFolder = 1; } -def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [ +def Torch_AtenRound_Op : Torch_Op<"aten.round_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::floor_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::round_ : (Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self ); @@ -3737,10 +4205,10 @@ def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFloor_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenRound_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenFloor_Op::print(OpAsmPrinter &printer) { + void AtenRound_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; @@ -4604,6 +5072,31 @@ def Torch_AtenBernoulliPOp : Torch_Op<"aten.bernoulli.p", [ }]; } +def Torch_AtenExponentialOp : Torch_Op<"aten.exponential", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::exponential : (Tensor, float, Generator?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_FloatType:$lambd, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenExponentialOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenExponentialOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenMultinomialOp : Torch_Op<"aten.multinomial", [ AllowsTypeRefinement, HasValueSemantics, @@ -4963,52 +5456,6 @@ def Torch_AtenTril_Op : Torch_Op<"aten.tril_", [ }]; } -def Torch_AtenRoundOp : Torch_Op<"aten.round", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::round : (Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenRoundOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenRoundOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; - let hasFolder = 1; -} - -def Torch_AtenRound_Op : Torch_Op<"aten.round_", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::round_ : (Tensor) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self - ); - let results = (outs - Torch_NonValueTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenRound_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenRound_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - def Torch_AtenIndexPutOp : Torch_Op<"aten.index_put", [ AllowsTypeRefinement, HasValueSemantics, @@ -5287,6 +5734,35 @@ def Torch_AtenCosineSimilarityOp : Torch_Op<"aten.cosine_similarity", [ }]; } +def Torch_AtenConv3dOp : Torch_Op<"aten.conv3d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::conv3d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$dilation, + Torch_IntType:$groups + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenConv3dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenConv3dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [ AllowsTypeRefinement, HasValueSemantics, @@ -5316,6 +5792,35 @@ def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [ }]; } +def Torch_AtenConv1dOp : Torch_Op<"aten.conv1d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::conv1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$dilation, + Torch_IntType:$groups + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenConv1dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenConv1dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + def Torch_AtenConvTranspose1dOp : Torch_Op<"aten.conv_transpose1d", [ AllowsTypeRefinement, HasValueSemantics, @@ -5406,6 +5911,61 @@ def Torch_AtenConvTranspose3dInputOp : Torch_Op<"aten.conv_transpose3d.input", [ }]; } +def Torch_AtenConvTbcOp : Torch_Op<"aten.conv_tbc", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::conv_tbc : (Tensor, Tensor, Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$weight, + AnyTorchTensorType:$bias, + Torch_IntType:$pad + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenConvTbcOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenConvTbcOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenConvTbcBackwardOp : Torch_Op<"aten.conv_tbc_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::conv_tbc_backward : (Tensor, Tensor, Tensor, Tensor, int) -> (Tensor, Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$input, + AnyTorchTensorType:$weight, + AnyTorchTensorType:$bias, + Torch_IntType:$pad + ); + let results = (outs + AnyTorchTensorType:$result0, + AnyTorchTensorType:$result1, + AnyTorchTensorType:$result2 + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenConvTbcBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 3); + } + void AtenConvTbcBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 3); + } + }]; +} + def Torch_AtenConvolutionOp : Torch_Op<"aten.convolution", [ AllowsTypeRefinement, HasValueSemantics, @@ -5653,6 +6213,37 @@ def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [ }]; } +def Torch_AtenInstanceNormOp : Torch_Op<"aten.instance_norm", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchOptionalTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + AnyTorchOptionalTensorType:$running_mean, + AnyTorchOptionalTensorType:$running_var, + Torch_BoolType:$use_input_stats, + Torch_FloatType:$momentum, + Torch_FloatType:$eps, + Torch_BoolType:$cudnn_enabled + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenInstanceNormOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 9, 1); + } + void AtenInstanceNormOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 9, 1); + } + }]; +} + def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [ AllowsTypeRefinement, HasValueSemantics, @@ -5670,17 +6261,45 @@ def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [ Torch_FloatType:$eps ); let results = (outs - AnyTorchTensorType:$result0, - AnyTorchTensorType:$result1, - AnyTorchTensorType:$result2 + AnyTorchTensorType:$result0, + AnyTorchTensorType:$result1, + AnyTorchTensorType:$result2 + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNativeGroupNormOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 8, 3); + } + void AtenNativeGroupNormOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 8, 3); + } + }]; +} + +def Torch_AtenGroupNormOp : Torch_Op<"aten.group_norm", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + Torch_IntType:$num_groups, + AnyTorchOptionalTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + Torch_FloatType:$eps, + Torch_BoolType:$cudnn_enabled + ); + let results = (outs + AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNativeGroupNormOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 8, 3); + ParseResult AtenGroupNormOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); } - void AtenNativeGroupNormOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 8, 3); + void AtenGroupNormOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); } }]; } @@ -5713,6 +6332,31 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [ }]; } +def Torch_AtenNormScalarOp : Torch_Op<"aten.norm.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$p + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNormScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenNormScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasVerifier = 1; +} + def Torch_AtenNormScalarOptDimOp : Torch_Op<"aten.norm.ScalarOpt_dim", [ AllowsTypeRefinement, HasValueSemantics, @@ -6508,6 +7152,31 @@ def Torch_Aten_AdaptiveAvgPool3dBackwardOp : Torch_Op<"aten._adaptive_avg_pool3d }]; } +def Torch_AtenAdaptiveMaxPool2dOp : Torch_Op<"aten.adaptive_max_pool2d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::adaptive_max_pool2d : (Tensor, int[]) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$output_size + ); + let results = (outs + AnyTorchTensorType:$result0, + AnyTorchTensorType:$result1 + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAdaptiveMaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 2); + } + void AtenAdaptiveMaxPool2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 2); + } + }]; +} + def Torch_AtenTopkOp : Torch_Op<"aten.topk", [ AllowsTypeRefinement, HasValueSemantics, @@ -7276,6 +7945,33 @@ def Torch_AtenLinalgVectorNormOp : Torch_Op<"aten.linalg_vector_norm", [ }]; } +def Torch_AtenLinalgNormOp : Torch_Op<"aten.linalg_norm", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalScalarType:$ord, + AnyTorchOptionalListOfTorchIntType:$dim, + Torch_BoolType:$keepdim, + AnyTorchOptionalIntType:$dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLinalgNormOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenLinalgNormOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + def Torch_AtenLinalgQrOp : Torch_Op<"aten.linalg_qr", [ AllowsTypeRefinement, HasValueSemantics, @@ -7681,6 +8377,78 @@ def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [ }]; } +def Torch_AtenReplicationPad2dOp : Torch_Op<"aten.replication_pad2d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::replication_pad2d : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$padding + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenReplicationPad2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenReplicationPad2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenReflectionPad1dOp : Torch_Op<"aten.reflection_pad1d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$padding + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenReflectionPad1dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenReflectionPad1dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenReflectionPad2dOp : Torch_Op<"aten.reflection_pad2d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$padding + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenReflectionPad2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenReflectionPad2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenPadOp : Torch_Op<"aten.pad", [ AllowsTypeRefinement, HasValueSemantics, @@ -7922,6 +8690,7 @@ def Torch_AtenOnesOp : Torch_Op<"aten.ones", [ printDefaultTorchOp(printer, *this, 5, 1); } }]; + let hasFolder = 1; } def Torch_AtenNewOnesOp : Torch_Op<"aten.new_ones", [ @@ -7977,6 +8746,7 @@ def Torch_AtenZerosOp : Torch_Op<"aten.zeros", [ printDefaultTorchOp(printer, *this, 5, 1); } }]; + let hasFolder = 1; } def Torch_AtenNewZerosOp : Torch_Op<"aten.new_zeros", [ @@ -8086,6 +8856,7 @@ def Torch_AtenTensorOp : Torch_Op<"aten.tensor", [ printDefaultTorchOp(printer, *this, 4, 1); } }]; + let hasFolder = 1; } def Torch_AtenTensorBoolOp : Torch_Op<"aten.tensor.bool", [ @@ -8188,6 +8959,7 @@ def Torch_Aten_ShapeAsTensorOp : Torch_Op<"aten._shape_as_tensor", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasFolder = 1; } def Torch_AtenIsnanOp : Torch_Op<"aten.isnan", [ @@ -8236,6 +9008,52 @@ def Torch_AtenIsinfOp : Torch_Op<"aten.isinf", [ }]; } +def Torch_AtenIsneginfOp : Torch_Op<"aten.isneginf", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::isneginf : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenIsneginfOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenIsneginfOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenIsposinfOp : Torch_Op<"aten.isposinf", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::isposinf : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenIsposinfOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenIsposinfOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenAllOp : Torch_Op<"aten.all", [ AllowsTypeRefinement, HasValueSemantics, @@ -8562,6 +9380,29 @@ def Torch_AtenEinsumOp : Torch_Op<"aten.einsum", [ }]; } +def Torch_AtenTraceOp : Torch_Op<"aten.trace", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::trace : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTraceOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenTraceOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [ AllowsTypeRefinement, HasValueSemantics, @@ -8610,6 +9451,7 @@ def Torch_AtenCloneOp : Torch_Op<"aten.clone", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasFolder = 1; } def Torch_AtenLiftFreshCopyOp : Torch_Op<"aten.lift_fresh_copy", [ @@ -9243,6 +10085,7 @@ def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; } def Torch_Aten_IndexPutImplOp : Torch_Op<"aten._index_put_impl", [ @@ -9344,6 +10187,7 @@ def Torch_AtenItemOp : Torch_Op<"aten.item", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasFolder = 1; } def Torch_AtenMaskedSelectOp : Torch_Op<"aten.masked_select", [ @@ -9606,6 +10450,7 @@ def Torch_AtenSelectIntOp : Torch_Op<"aten.select.int", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; } def Torch_AtenSizeIntOp : Torch_Op<"aten.size.int", [ @@ -10140,6 +10985,7 @@ def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; } def Torch_AtenWhereScalarOp : Torch_Op<"aten.where.Scalar", [ @@ -10165,6 +11011,7 @@ def Torch_AtenWhereScalarOp : Torch_Op<"aten.where.Scalar", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; } def Torch_AtenWhereScalarOtherOp : Torch_Op<"aten.where.ScalarOther", [ @@ -10190,6 +11037,7 @@ def Torch_AtenWhereScalarOtherOp : Torch_Op<"aten.where.ScalarOther", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; } def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [ @@ -10215,6 +11063,33 @@ def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; +} + +def Torch_AtenNanToNumOp : Torch_Op<"aten.nan_to_num", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalFloatType:$nan, + AnyTorchOptionalFloatType:$posinf, + AnyTorchOptionalFloatType:$neginf + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNanToNumOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenNanToNumOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; } def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [ @@ -10442,6 +11317,7 @@ def Torch_AtenIntImplicitOp : Torch_Op<"aten.IntImplicit", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasCanonicalizer = 1; } def Torch_AtenFloatImplicitOp : Torch_Op<"aten.FloatImplicit", [ @@ -10465,6 +11341,7 @@ def Torch_AtenFloatImplicitOp : Torch_Op<"aten.FloatImplicit", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasCanonicalizer = 1; } def Torch_AtenTensorFloatOp : Torch_Op<"aten.tensor.float", [ @@ -10686,6 +11563,7 @@ def Torch_AtenFullOp : Torch_Op<"aten.full", [ printDefaultTorchOp(printer, *this, 6, 1); } }]; + let hasFolder = 1; } def Torch_AtenFullLikeOp : Torch_Op<"aten.full_like", [ @@ -10877,33 +11755,59 @@ def Torch_AtenUniqueConsecutiveOp : Torch_Op<"aten.unique_consecutive", [ }]; } -def Torch_AtenLinspaceOp : Torch_Op<"aten.linspace", [ +def Torch_AtenLinspaceOp : Torch_Op<"aten.linspace", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)`"; + let arguments = (ins + AnyTorchScalarType:$start, + AnyTorchScalarType:$end, + Torch_IntType:$steps, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLinspaceOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenLinspaceOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + +def Torch_AtenLinalgCrossOp : Torch_Op<"aten.linalg_cross", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)`"; + let summary = "Generated op for `aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)`"; let arguments = (ins - AnyTorchScalarType:$start, - AnyTorchScalarType:$end, - Torch_IntType:$steps, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalIntType:$layout, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory + AnyTorchTensorType:$self, + AnyTorchTensorType:$other, + Torch_IntType:$dim ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLinspaceOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 7, 1); + ParseResult AtenLinalgCrossOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenLinspaceOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 7, 1); + void AtenLinalgCrossOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasVerifier = 1; } def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [ @@ -10978,6 +11882,31 @@ def Torch_AtenAsStridedCopyOp : Torch_Op<"aten.as_strided_copy", [ }]; } +def Torch_AtenDiagonalOp : Torch_Op<"aten.diagonal", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::diagonal : (Tensor, int, int, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$offset, + Torch_IntType:$dim1, + Torch_IntType:$dim2 + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenDiagonalOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenDiagonalOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenDiagonalCopyOp : Torch_Op<"aten.diagonal_copy", [ AllowsTypeRefinement, HasValueSemantics, @@ -11563,6 +12492,33 @@ def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_at }]; } +def Torch_AtenGridSamplerOp : Torch_Op<"aten.grid_sampler", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$grid, + Torch_IntType:$interpolation_mode, + Torch_IntType:$padding_mode, + Torch_BoolType:$align_corners + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenGridSamplerOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenGridSamplerOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [ AllowsTypeRefinement, HasValueSemantics, @@ -11749,6 +12705,7 @@ def Torch_AtenCatOp : Torch_Op<"aten.cat", [ } }]; let hasFolder = 1; + let hasCanonicalizer = 1; } def Torch_AtenStackOp : Torch_Op<"aten.stack", [ @@ -12013,6 +12970,7 @@ def Torch_AtenSortOp : Torch_Op<"aten.sort", [ printDefaultTorchOp(printer, *this, 3, 2); } }]; + let hasFolder = 1; } def Torch_AtenSplitTensorOp : Torch_Op<"aten.split.Tensor", [ @@ -12063,6 +13021,30 @@ def Torch_AtenSplitWithSizesOp : Torch_Op<"aten.split_with_sizes", [ }]; } +def Torch_AtenSplitSizesOp : Torch_Op<"aten.split.sizes", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::split.sizes : (Tensor, int[], int) -> (Tensor[])`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$split_size, + Torch_IntType:$dim + ); + let results = (outs + AnyTorchListOfTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSplitSizesOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenSplitSizesOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenUnbindIntOp : Torch_Op<"aten.unbind.int", [ AllowsTypeRefinement, ReadOnly @@ -12244,6 +13226,29 @@ def Torch_AtenJoinOp : Torch_Op<"aten.join", [ }]; } +def Torch_AtenWarnOp : Torch_Op<"aten.warn", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::warn : (str, int) -> ()`"; + let arguments = (ins + Torch_StringType:$message, + Torch_IntType:$stacklevel + ); + let results = (outs + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenWarnOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 0); + } + void AtenWarnOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 0); + } + }]; +} + def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [ AllowsTypeRefinement, HasValueSemantics, @@ -13348,6 +14353,31 @@ def Torch_Aten_SetItemTOp : Torch_Op<"aten._set_item.t", [ }]; } +def Torch_AtenMulOp : Torch_Op<"aten.mul", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::mul : (Scalar, Scalar) -> (Scalar)`"; + let arguments = (ins + AnyTorchScalarType:$a, + AnyTorchScalarType:$b + ); + let results = (outs + AnyTorchScalarType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMulOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenMulOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + def Torch_AtenDivOp : Torch_Op<"aten.div", [ AllowsTypeRefinement, HasValueSemantics, @@ -14029,6 +15059,179 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [ }]; } +def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$scales, + AnyTorchTensorType:$zero_points, + Torch_IntType:$axis, + Torch_IntType:$dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenQuantizePerChannelOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenQuantizePerChannelOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + +def Torch_AtenQuantizePerTensorOp : Torch_Op<"aten.quantize_per_tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_FloatType:$scale, + Torch_IntType:$zero_point, + Torch_IntType:$dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenQuantizePerTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenQuantizePerTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenDequantizeSelfOp : Torch_Op<"aten.dequantize.self", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::dequantize.self : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenDequantizeSelfOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenDequantizeSelfOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenDequantizeTensorOp : Torch_Op<"aten.dequantize.tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::dequantize.tensor : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$qtensor + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenDequantizeTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenDequantizeTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenIntReprOp : Torch_Op<"aten.int_repr", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::int_repr : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenIntReprOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenIntReprOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_Aten_MakePerChannelQuantizedTensorOp : Torch_Op<"aten._make_per_channel_quantized_tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$scale, + AnyTorchTensorType:$zero_point, + Torch_IntType:$axis + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_MakePerChannelQuantizedTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void Aten_MakePerChannelQuantizedTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_Aten_MakePerTensorQuantizedTensorOp : Torch_Op<"aten._make_per_tensor_quantized_tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_FloatType:$scale, + Torch_IntType:$zero_point + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_MakePerTensorQuantizedTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void Aten_MakePerTensorQuantizedTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [ AllowsTypeRefinement, HasValueSemantics, @@ -14161,6 +15364,7 @@ def Torch_PrimNumToTensorScalarOp : Torch_Op<"prim.NumToTensor.Scalar", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasFolder = 1; } def Torch_PrimMinSelfIntOp : Torch_Op<"prim.min.self_int", [ diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h index 64b70e097c39..e6a9e1622cc1 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h @@ -294,6 +294,21 @@ bool isListPotentiallyMutated(Value list); /// the list. bool potentiallyMutatesListOperands(Operation *op); +/// Returns the value from an `IntegerAttr` as an `int64_t`. +/// +/// @param intAttr the `IntegerAttr` from which to extract the value +/// @return the value as an `int64_t` +/// +/// Regardless of the signed-ness of the attribute, this function returns +/// the value as a signed integer, which implies that if the attribute has +/// a 64-bit unsigned value, it will be converted to an int64_t in the manner +/// that uint64_t is cast to int64_t in C++. +inline int64_t getIntAttrAsSigned(IntegerAttr intAttr) { + if (intAttr.getType().isUnsignedInteger()) + return intAttr.getValue().getZExtValue(); + return intAttr.getValue().getSExtValue(); +} + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index c86244f5f1e3..f5214db58f19 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -843,12 +843,23 @@ def Torch_OperatorOp : Torch_Op<"operator", [ let arguments = (ins StrAttr:$name, Variadic:$operands); let results = (outs Variadic:$results); + let regions = (region VariadicRegion:$regions); let assemblyFormat = [{ - $name `(` $operands `)` attr-dict `:` functional-type($operands, $results) + $name `(` $operands `)` attr-dict `:` functional-type($operands, $results) $regions }]; } +def Torch_OperatorTerminatorOp : Torch_Op<"operator_terminator", [Terminator, + HasParent<"::mlir::torch::Torch::OperatorOp">]> { + let summary = "Implicit terminator for torch.operator"; + + let arguments = (ins Variadic:$operands); + let results = (outs); + + let assemblyFormat = "$operands attr-dict `:` type($operands)"; +} + def Torch_LinearParamsCreateOp : Torch_Op<"linear_params.create", [ AllowsTypeRefinement, AllowedInModuleInitializer, diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTraits.h b/include/torch-mlir/Dialect/Torch/IR/TorchTraits.h index 20f1bc109885..271481f0ae8a 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTraits.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTraits.h @@ -36,8 +36,7 @@ class HasValueSemantics // This is a weaker form of HasValueSemantics, since that trait also requires no // aliasing. That is, HasValueSemantics implies this trait. template -class ReadOnly - : public ::mlir::OpTrait::TraitBase {}; +class ReadOnly : public ::mlir::OpTrait::TraitBase {}; // If a Torch op has this trait, it means that the op is a "trailing underscore" // op variant that performs an in-place operation on its first argument. These @@ -62,7 +61,8 @@ class AllowsTypeRefinement // by the IValue importer. template class AllowedInModuleInitializer - : public ::mlir::OpTrait::TraitBase {}; + : public ::mlir::OpTrait::TraitBase {}; } // namespace OpTrait } // namespace Torch diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h index de77a1a8f8a3..c8d1c5051f28 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h @@ -32,7 +32,7 @@ class ValueTensorType; /// Common getter function signature that covers all tensor types. /// Used for sharing code between NonValueTensorType and ValueTensorType. using GetTensorTypeFn = llvm::function_ref>, Type)>; + MLIRContext *, std::optional>, Type, Attribute)>; /// The representation of an unknown dimension size in an ArrayRef. constexpr static int64_t kUnknownSize = -1; diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td index 1f7231b3500a..898c768ae1c2 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td @@ -63,12 +63,13 @@ class AnyTorchTensorType ``` tensor-type ::= (`!torch.tensor` | `!torch.vtensor`) tensor-modifiers? - tensor-modifiers ::= `<` sizes-spec `,` dtype-spec `>` + tensor-modifiers ::= `<` sizes-spec `,` dtype-spec (',' sparsity)? `>` sizes-spec ::= `*` | `[` size-list `]` size-list ::= /*empty*/ | size-list-nonempty size-list-nonempty = size (`,` size)* size ::= `?` | decimal-literal dtype-spec ::= `unk` | type + sparsity ::= attribute-value ``` Represents a multi-dimensional array to model Torch's `torch.Tensor` type. @@ -133,6 +134,12 @@ class AnyTorchTensorType |-------------------|--------------------| ``` + The `sparsity` attribute directly mirrors the additional tensor `encoding` + defined by upstream MLIR on the RankedTensorType. Unlike the upstream + attribute, however, this attribute is exclusively used to denote a + straightforward tensor (with an empty attribute) or a sparse tensor + (with a SparseTensorEncodingAttr). + TODO: Support the full set of Torch dtypes. TODO: Use si1? @@ -149,8 +156,20 @@ class AnyTorchTensorType }]; let parameters = (ins OptionalArrayRefTorchParameter<"int64_t", "sizes of dimensions">:$optionalSizes, - "::mlir::Type":$optionalDtype + "::mlir::Type":$optionalDtype, + "Attribute":$optionalSparsity ); + let builders = [ + // Provide builder where optionalSparsity is empty by default. + TypeBuilder<(ins + "::std::optional>":$optionalSizes, + "::mlir::Type":$optionalDtype, + CArg<"Attribute", "{}">:$optionalSparsity + ), [{ + return $_get(context, optionalSizes, optionalDtype, optionalSparsity); + }]> + ]; + let skipDefaultBuilders = 1; let genVerifyDecl = 1; let hasCustomAssemblyFormat = 1; string extraBaseClassDeclaration = [{ @@ -306,6 +325,17 @@ def Torch_QUInt8Type : Torch_Type<"QUInt8", "quint8"> { }]; } +def Torch_QInt32Type : Torch_Type<"QInt32", "qint32"> { + let summary = "Type modeling `ScalarType::QInt32`"; + let description = [{ + This is intended to be a 1:1 match for the Torch `ScalarType` types. + + Looking at the variety / ad-hocness (e.g. `QUInt4x2`) of that set of + types, it is deemed preferable to import them as one-off ad-hoc types + instead of a single parameterized type. + }]; +} + def Torch_LinearParamsType : Torch_Type<"LinearParams", "LinearParams"> { let summary = "Torch packed linear params type"; let description = [{ diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h index 84efddcc93d4..71111c00cd28 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h @@ -61,7 +61,8 @@ struct TorchLoweringPipelineOptions Option extraLibrary{ *this, "extra-library", - llvm::cl::desc("Filename of MLIR module for splicing into the abstract interpretation library.")}; + llvm::cl::desc("Filename of MLIR module for splicing into the abstract " + "interpretation library.")}; }; /// Creates a pipeline that lowers the object graph IR that is produced by @@ -106,6 +107,10 @@ createDecomposeComplexOpsPass(ArrayRef legalOps); std::unique_ptr> createRecomposeComplexOpsPass(); +std::unique_ptr> createFuseQuantizedOpsPass(); +std::unique_ptr> +createMatchQuantizedCustomOpsPass(); + std::unique_ptr> createReifyShapeCalculationsPass(StringRef extraLibrary); @@ -121,8 +126,7 @@ createSimplifyDtypeCalculationsPass(); std::unique_ptr> createDropAbstractInterpCalculationsPass(); -std::unique_ptr> -createEraseModuleInitializerPass(); +std::unique_ptr> createEraseModuleInitializerPass(); std::unique_ptr> createLowerToBackendContractPass(int maxIterations, bool decompose, diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td index 8967855c2e52..7b52d786610e 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td @@ -258,6 +258,34 @@ def RecomposeComplexOps : Pass<"torch-recompose-complex-ops", "func::FuncOp"> { }]; } +def FuseQuantizedOps : Pass<"torch-fuse-quantized-ops", "func::FuncOp"> { + let summary = "QDQ: Fuse recognized QDQ op sequences."; + let constructor = "mlir::torch::Torch::createFuseQuantizedOpsPass()"; + let description = [{ + Torch models often represents quantized operations as the sequence: + Dequantize + DenseOp + Quantize + This allows the existing dense operations to be used without specifically + representing quantized types. It is more computationally efficient to + perform the dense operation in the quantized domain, so we fuse the + quantization / dequantization behavior together and represent as purely + quantized operations. + }]; +} + +def MatchQuantizedCustomOps : Pass<"torch-match-quantized-custom-ops", "func::FuncOp"> { + let summary = "Match quantized operations that occur in different namespace."; + let constructor = "mlir::torch::Torch::createMatchQuantizedCustomOpsPass()"; + let description = [{ + Torch quantization utilities generated custom op versions of known aten + quantziation operations. We can match these specially named operations and + rewrite to the corresponding aten quantized operations. + + We handle this post import to maintain a simplified import process. + }]; +} + def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> { let summary = "Reify shape calculations."; let constructor = [{ diff --git a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h index efb114fbfa14..043dd92549b2 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h +++ b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h @@ -140,12 +140,7 @@ enum Reduction { None, Mean, Sum, END }; // Source: // https://github.com/pytorch/pytorch/blob/master/c10/core/MemoryFormat.h //===----------------------------------------------------------------------===// -enum MemoryFormat { - Contiguous, - Preserve, - ChannelsLast, - ChannelsLast3d -}; +enum MemoryFormat { Contiguous, Preserve, ChannelsLast, ChannelsLast3d }; //===----------------------------------------------------------------------===// // Possible values for `layout` argument in PyTorch ops that support it. diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index 842c86defb74..44f977d5d0ed 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -11,6 +11,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" +#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" namespace mlir { @@ -131,6 +132,28 @@ LogicalResult checkDefaultStrideHelper(Operation *op, PatternRewriter &rewriter, Value opSize, Value opStride, Location loc); +// Helper to create a tensor filled with the given scalar. Scalar would be +// converted the to the element type of the given tensor type. +Value createInitTensor(PatternRewriter &rewriter, Location loc, + BaseTensorType resultType, Value scalar, Value sizeList); + +// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar` +// would be converted to the element type of the given `inputType`. +Value createRank0Tensor(PatternRewriter &rewriter, Location loc, + BaseTensorType inputType, Value scalar); + +LogicalResult getTransposedType(BaseTensorType inType, int64_t dimA, + int64_t dimB, Type &transposedType); + +// Approximates the heuristic in the torch `acc_type` template for kernels +// that are defined in terms of it. For now, this just returns accumulators +// as if for CUDA from that implementation. In the future, this could be +// extended to look at hints on the `forOp` or its container to better +// control the behavior. Such support would be done in coordination with +// the fx_importer and APIs, which could add hints to the IR (based on +// Torch flags, user options, etc). +Type getDefaultAccType(PatternRewriter &rewriter, Type inputType); + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index d762bd840f7f..2f70cf990219 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -10,7 +10,7 @@ #ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H #define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H -#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -54,7 +54,7 @@ createVerifyStablehloBackendContractPass(); std::unique_ptr> createFuncBackendTypeConversionPass(); -std::unique_ptr> +std::unique_ptr> createFinalizingBackendTypeConversionPass(); // These passes do a one-off conversion of a specific kind of quantized group @@ -62,8 +62,10 @@ createFinalizingBackendTypeConversionPass(); // obviate them but that are being carried for now in order to unblock progress // on full integrations. See https://github.com/llvm/torch-mlir/issues/2417 for // the plan to support a more generalized lowering for these graphs. -std::unique_ptr> createUnpackQuantTensorPass(); -std::unique_ptr> createConvertCustomQuantOpPass(); +std::unique_ptr> +createUnpackQuantTensorPass(); +std::unique_ptr> +createConvertCustomQuantOpPass(); std::unique_ptr> createVerifyLinalgOnTensorsBackendContractPass(); diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td index 4d3e16a81c5c..73654c6f8034 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td @@ -22,7 +22,7 @@ def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "Modu } def FinalizingBackendTypeConversion - : Pass<"torch-finalizing-backend-type-conversion", "func::FuncOp"> { + : InterfacePass<"torch-finalizing-backend-type-conversion", "mlir::FunctionOpInterface"> { let summary = "Finalizes a partial conversion to builtin tensors"; let constructor = "mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass()"; @@ -51,12 +51,12 @@ def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contra // The following passes are for a one-off conversion of a specific kind of quantized group matmul. // They should not be included in default lowering flows until further along. -def UnpackQuantTensor : Pass<"torch-unpack-quant-tensor", "func::FuncOp"> { +def UnpackQuantTensor : InterfacePass<"torch-unpack-quant-tensor", "mlir::FunctionOpInterface"> { let summary = "Unpack quantized int4 tensor from int8 containter"; let constructor = "mlir::torch::TorchConversion::createUnpackQuantTensorPass()"; } -def ConvertCustomQuantOp : Pass<"torch-convert-custom-quant-op", "func::FuncOp"> { +def ConvertCustomQuantOp : InterfacePass<"torch-convert-custom-quant-op", "mlir::FunctionOpInterface"> { let summary = "Convert torch custom quant op to linalg"; let constructor = "mlir::torch::TorchConversion::createConvertCustomQuantOpPass()"; } diff --git a/include/torch-mlir/RefBackend/Passes.h b/include/torch-mlir/RefBackend/Passes.h index 8f1b2b525a22..be5e43a1e63c 100644 --- a/include/torch-mlir/RefBackend/Passes.h +++ b/include/torch-mlir/RefBackend/Passes.h @@ -31,6 +31,8 @@ std::unique_ptr> createMLProgramBufferizePass(); std::unique_ptr> createMungeMemrefCopyPass(); +std::unique_ptr> createGeneralizeTensorConcatPass(); + std::unique_ptr> createGeneralizeTensorPadPass(); } // namespace RefBackend } // namespace torch diff --git a/include/torch-mlir/RefBackend/Passes.td b/include/torch-mlir/RefBackend/Passes.td index 12d182e49e3a..3d8b7fd41b1b 100644 --- a/include/torch-mlir/RefBackend/Passes.td +++ b/include/torch-mlir/RefBackend/Passes.td @@ -35,6 +35,11 @@ def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "func::FuncOp"> { let dependentDialects = ["memref::MemRefDialect"]; } +def GeneralizeTensorConcat : Pass<"refback-generalize-tensor-concat", "func::FuncOp"> { + let summary = "Convert tensor.concat to other tensor ops"; + let constructor = "mlir::torch::RefBackend::createGeneralizeTensorConcatPass()"; +} + def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "func::FuncOp"> { let summary = "Convert tensor.pad to linalg ops"; let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()"; diff --git a/lib/CAPI/Dialects.cpp b/lib/CAPI/Dialects.cpp index 06be821c0cfd..048e37e083a3 100644 --- a/lib/CAPI/Dialects.cpp +++ b/lib/CAPI/Dialects.cpp @@ -9,7 +9,8 @@ #include "torch-mlir-c/Dialects.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "mlir/CAPI/Registration.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" -MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Torch, torch, mlir::torch::Torch::TorchDialect) +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Torch, torch, + mlir::torch::Torch::TorchDialect) diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index d9030c23a66f..e4ba46138f34 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1,3 +1,5 @@ +torch_mlir_enable_werror() + add_subdirectory(CAPI) add_subdirectory(Conversion) add_subdirectory(Dialect) @@ -29,6 +31,10 @@ set(LinkedLibs TorchMLIRTorchOnnxToTorch ) +if(TORCH_MLIR_ENABLE_STABLEHLO) +list(APPEND LinkedLibs StablehloPasses StablehloLinalgTransforms) +endif() + if(TORCH_MLIR_ENABLE_REFBACKEND) add_subdirectory(RefBackend) list(APPEND LinkedLibs TorchMLIRRefBackend) diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index afbe775d3a20..dd9e94a50080 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(TorchOnnxToTorch) add_subdirectory(TorchToLinalg) add_subdirectory(TorchToSCF) add_subdirectory(TorchToArith) +add_subdirectory(TorchToTensor) add_subdirectory(TorchToTosa) if(TORCH_MLIR_ENABLE_STABLEHLO) add_subdirectory(TorchToStablehlo) @@ -14,6 +15,7 @@ add_subdirectory(Utils) set(linked_libs TorchMLIRTorchToLinalg TorchMLIRTorchToSCF TorchMLIRTorchToArith + TorchMLIRTorchToTensor TorchMLIRTorchToTosa TorchMLIRTorchToTMTensor TorchMLIRTorchConversionToMLProgram diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index 0dae24678a4b..6d8adbaa146d 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -13,12 +13,13 @@ #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #endif // TORCH_MLIR_ENABLE_STABLEHLO +#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" +#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" -#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" -#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" -#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" +#include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h" +#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" //===----------------------------------------------------------------------===// // Pass registration @@ -29,6 +30,4 @@ namespace { #include "torch-mlir/Conversion/Passes.h.inc" } // end namespace -void mlir::torch::registerConversionPasses() { - ::registerPasses(); -} +void mlir::torch::registerConversionPasses() { ::registerPasses(); } diff --git a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp index eab81c2bec18..6a00e5190f4b 100644 --- a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp +++ b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp @@ -82,7 +82,8 @@ class ConvertGetNextSeedOp : public OpConversionPattern { // temp = multiplier * currentSeed + incrementStep Value mul = rewriter.create(loc, currentSeed, multiplier); Value seed = rewriter.create(loc, mul, incrementStep); - globalVar = rewriter.create(loc, seed, globalVar, ValueRange()); + globalVar = + rewriter.create(loc, seed, globalVar, ValueRange()); rewriter.create( loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()), globalVar); diff --git a/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt b/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt index 807db64eac64..4a5015816609 100644 --- a/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt +++ b/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_conversion_library(TorchMLIRTorchOnnxToTorch Passes.cpp Patterns.cpp TorchOnnxToTorch.cpp + Utils.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchOnnxToTorch diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 44ced9eb4b64..2e3f3e8b8053 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -7,13 +7,77 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/DialectResourceBlobManager.h" #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "llvm/Support/FormatVariadic.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::onnx_c; +class Endian { +private: + static constexpr uint32_t uint32_ = 0x01020304; + static constexpr uint8_t magic_ = (const uint8_t &)uint32_; + +public: + static constexpr bool little = magic_ == 0x04; + static constexpr bool big = magic_ == 0x01; + static_assert(little || big, "Cannot determine endianness!"); + +private: + Endian() = delete; +}; + +static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) { + // TODO: Add complete mapping. + // Where are the ONNX and PyTorch dtype enums defined? + // ONNX: + // https://github.com/shouxieai/tensorRT_Pro/blob/main/onnx/onnx-ml.proto + // PyTorch: + // https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h#L88 + + int64_t dtypeIntTorch = [dtypeIntOnnx]() { + switch (dtypeIntOnnx) { + case 1: + return 6; // float + case 7: + return 5; // int64 + case 9: + return 11; // bool + case 10: + return 5; // half + case 11: + return 7; // double + case 16: + return 15; // bfloat16 + default: + return -1; // No dtype + } + }(); + + return dtypeIntTorch; +} + +static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter, + Location loc, Value input, + int64_t dimA, int64_t dimB, + Value &transposed) { + Type transposedType; + if (failed(getTransposedType(input.getType().cast(), + dimA, dimB, transposedType))) + return failure(); + Value cstDimA = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimA)); + Value cstDimB = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimB)); + transposed = rewriter.create( + loc, transposedType, input, cstDimA, cstDimB); + return success(); +} + // Simple rewrites for the default domain. // See: https://onnx.ai/onnx/operators/ // For operators that are effectively version invariant, we register with @@ -39,7 +103,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); - // TODO: Acosh unimplemented in torch-mlir // Add became forward compatible with Torch in version 7. patterns.onOp("Add", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -139,9 +202,28 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand, constAxis, constKeepDims); return success(); }); - // TODO: Asin unimplemented in torch-mlir - // TODO: Asinh unimplemented in torch-mlir - // TODO: Atanh unimplemented in torch-mlir + patterns.onOp("Asin", 7, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp("Asinh", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); patterns.onOp("Atan", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -153,6 +235,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); + patterns.onOp("Atanh", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); patterns.onOp("Acos", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -164,48 +257,254 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); + patterns.onOp("Acosh", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp("BatchNormalization", 15, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input, weight, bias, runningMean, runningVar; + bool training; + float momentum, eps; + if (binder.s64BoolAttr(training, "training_mode", 0)) + return failure(); + if (training) { + // TODO: Add support for training = true + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: training = true"); + } + + if (binder.tensorOperandAtIndex(input, 0) || + binder.tensorOperandAtIndex(weight, 1) || + binder.tensorOperandAtIndex(bias, 2) || + binder.tensorOperandAtIndex(runningMean, 3) || + binder.tensorOperandAtIndex(runningVar, 4) || + binder.f32FloatAttr(momentum, "momentum", 0.9f) || + binder.f32FloatAttr(eps, "epsilon", 1e-05f) || + binder.tensorResultType(resultType)) + return failure(); + + Value cstFalse = rewriter.create( + binder.getLoc(), false); + Value cstMomentum = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(momentum)); + Value cstEps = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(eps)); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, weight, bias, runningMean, + runningVar, /*training=*/cstFalse, cstMomentum, cstEps, + /*cudnn_enabled=*/cstFalse); + return success(); + }); patterns.onOp( - "BitShift", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "AveragePool", 19, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + std::string autoPad; + SmallVector dilation; + if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) + return failure(); + if (autoPad != "NOTSET") { + // TODO: Add support for `auto_pad` != "NOTSET" + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: auto_pad != NOTSET"); + } + if (binder.s64IntegerArrayAttr(dilation, "dilations", {})) { + return failure(); + } + if (dilation.size() > 0) { + return rewriter.notifyMatchFailure( + binder.op, "dilation is not supported by torch.aten.avgpool op"); + } + Torch::ValueTensorType resultType; - Value lhs, rhs; - std::string direction; - if (binder.tensorOperands(lhs, rhs) || - binder.tensorResultType(resultType) || - binder.customOpNameStringAttr(direction, "direction", "")) + Value operand; + bool ceilMode, countIncludePad; + if (binder.tensorOperand(operand) || + binder.s64BoolAttr(ceilMode, "ceil_mode", false) || + binder.s64BoolAttr(countIncludePad, "count_include_pad", false) || + binder.tensorResultType(resultType)) + return failure(); + // Determine the rank of input tensor. + std::optional maybeRank = Torch::getTensorRank(operand); + if (!maybeRank) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: unranked tensor"); + unsigned rank = *maybeRank; + + SmallVector kernel, padding, strides; + if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {})) { return failure(); - if (direction == "LEFT") { - rewriter.replaceOpWithNewOp( - binder.op, resultType, lhs, rhs); - } else { - rewriter.replaceOpWithNewOp( - binder.op, resultType, lhs, rhs); } - return success(); + if (kernel.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, "kernel list size does not match the number of axes"); + } + SmallVector defaultPadding(2 * (rank - 2), 0); + if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) { + return failure(); + } + if (padding.size() != 2 * (rank - 2)) { + return rewriter.notifyMatchFailure( + binder.op, + "padding list size does not match twice the number of axes"); + } + if (binder.s64IntegerArrayAttr(strides, "strides", {1})) { + return failure(); + } + if (strides.size() != 1 && strides.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, "strides list size does not match the number of axes"); + } + + SmallVector cstKernel, cstPadding, cstStrides; + for (int64_t i : kernel) { + cstKernel.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + for (int64_t i : padding) { + cstPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + for (int64_t i : strides) { + cstStrides.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + Value kernelSizeList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstKernel); + Value paddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstPadding); + Value stridesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstStrides); + Value cstCeilMode = + rewriter.create(binder.getLoc(), ceilMode); + Value cstCountIncludePad = rewriter.create( + binder.getLoc(), countIncludePad); + Value cstNone = rewriter.create(binder.getLoc()); + + if (rank == 3) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, cstCeilMode, cstCountIncludePad); + return success(); + } else if (rank == 4) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, cstCeilMode, cstCountIncludePad, + /*divisor_override=*/cstNone); + return success(); + } else if (rank == 5) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, cstCeilMode, cstCountIncludePad, + /*divisor_override=*/cstNone); + return success(); + } + return failure(); }); patterns.onOp( - "BitwiseAnd", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "Bernoulli", 15, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; - Value lhs, rhs; - std::string direction; - if (binder.tensorOperands(lhs, rhs) || + Value input; + int64_t dtypeIntOnnx, dtypeIntTorch; + if (binder.tensorOperand(input) || + binder.s64IntegerAttr(dtypeIntOnnx, "dtype", -1) || binder.tensorResultType(resultType)) return failure(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, lhs, rhs); + + SmallString<64> name("torch.onnx."); + name.append("seed"); + auto attr = binder.op->getAttr(name); + if (attr) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support not present for seed attribute"); + } + + Value none = rewriter.create(binder.getLoc()); + Value bernoulli = rewriter.create( + binder.getLoc(), input.getType(), input, /*generator=*/none); + + if (dtypeIntOnnx == -1) { + // True, if dtype attribute value is not present. + rewriter.replaceOp(binder.op, bernoulli); + return success(); + } + dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + if (dtypeIntTorch == -1) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented support for the given dtype conversion"); + } + Value constDtype = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + dtypeIntTorch)); + Value cstFalse = + rewriter.create(binder.getLoc(), false); + rewriter.replaceOpWithNewOp( + binder.op, resultType, bernoulli, constDtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); return success(); }); patterns.onOp( - "BitwiseOr", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "BitShift", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; std::string direction; if (binder.tensorOperands(lhs, rhs) || - binder.tensorResultType(resultType)) + binder.tensorResultType(resultType) || + binder.customOpNameStringAttr(direction, "direction", "")) return failure(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, lhs, rhs); + if (direction == "LEFT") { + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + } else { + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + } return success(); }); + patterns.onOp("BitwiseAnd", 18, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + std::string direction; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp("BitwiseOr", 18, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + std::string direction; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); patterns.onOp("BitwiseNot", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -217,20 +516,20 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); + patterns.onOp("BitwiseXor", 18, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + std::string direction; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); patterns.onOp( - "BitwiseXor", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value lhs, rhs; - std::string direction; - if (binder.tensorOperands(lhs, rhs) || - binder.tensorResultType(resultType)) - return failure(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, lhs, rhs); - return success(); - }); - patterns.onOp( - "Cast", 19, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "Cast", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; int64_t dtypeIntOnnx, dtypeIntTorch; @@ -239,24 +538,14 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.tensorResultType(resultType)) return failure(); - // TODO: Add complete mapping. - switch (dtypeIntOnnx) { - case 1: - dtypeIntTorch = 6; // float - break; - case 10: - dtypeIntTorch = 5; // half - break; - case 11: - dtypeIntTorch = 7; // double - break; - case 16: - dtypeIntTorch = 15; // bfloat16 - break; - default: - return rewriter.notifyMatchFailure( - binder.op, - "unimplemented support for the given dtype conversion"); + dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + if (dtypeIntTorch == -1) { + auto message = llvm::formatv("unimplemented support for the given " + "dtype conversion (onnx 'type' = {0})", + dtypeIntOnnx); + auto y = rewriter.notifyMatchFailure(binder.op, message); + + return y; } Value constDtype = rewriter.create( binder.getLoc(), rewriter.getType(), @@ -271,6 +560,36 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( /*memory_format=*/none); return success(); }); + patterns.onOp( + "CastLike", 15, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input, target; + if (binder.tensorOperands(input, target) || + binder.tensorResultType(resultType)) + return failure(); + + // TODO: Add support to handle the `saturate` attribute. + // Ignoring it right now, since it's only using during the float8 + // conversions which are not supported in Torch-MLIR right now. + + Torch::ValueTensorType targetTy = + target.getType().cast(); + if (!targetTy.hasDtype()) { + return rewriter.notifyMatchFailure(binder.op, + "target tensor must have a dtype"); + } + Type targetDtype = targetTy.getDtype(); + Value constDtype = Torch::getDtypeIntValueForType( + rewriter, binder.getLoc(), targetDtype); + Value none = rewriter.create(binder.getLoc()); + Value cstFalse = + rewriter.create(binder.getLoc(), false); + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, constDtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); + return success(); + }); patterns.onOp("Ceil", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -283,39 +602,553 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return success(); }); patterns.onOp( - "Clip", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "Celu", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + float alpha; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType) || + binder.f32FloatAttr(alpha, "alpha", 1.0f)) + return failure(); + // exp(x/alpha) + Value constAlpha = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(alpha)); + Value xDivAlpha = rewriter.create( + binder.getLoc(), resultType, operand, constAlpha); + Value expXDivAlpha = rewriter.create( + binder.getLoc(), resultType, xDivAlpha); + // alpha * (exp(x/alpha) - 1) + Value constantOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value subOne = rewriter.create( + binder.getLoc(), resultType, expXDivAlpha, constantOne, + constantOne); + Value mulAlpha = rewriter.create( + binder.getLoc(), resultType, subOne, constAlpha); + Value constantZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value zeroTensor = createRank0Tensor(rewriter, binder.getLoc(), + resultType, constantZero); + // min(0, alpha * (exp(x/alpha) - 1)) + Value minExpression = rewriter.create( + binder.getLoc(), resultType, zeroTensor, mulAlpha); + + // max(0, x) + Value maxExpression = rewriter.create( + binder.getLoc(), resultType, zeroTensor, operand); + // max(0,x) + min(0, alpha * (exp(x/alpha) - 1)) + rewriter.replaceOpWithNewOp( + binder.op, resultType, maxExpression, minExpression, constantOne); + return success(); + }); + patterns.onOp( + "Clip", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // https://onnx.ai/onnx/operators/onnx__Clip.html + + // Inputs and outputs must be tensors. + Value source; Torch::ValueTensorType resultType; - if (binder.op->getNumOperands() == 1) { - Value source; - if (binder.tensorOperand(source) || - binder.tensorResultType(resultType)) + if (binder.tensorOperandAtIndex(source, 0) || + binder.tensorResultType(resultType)) { + return failure(); + } + + // Min and max can be args (version 11+) or attributes (version 6-). + // They default to numeric_limits::lowest() and numeric_limits::max(). + Value min; + Value max; + if (binder.op->getNumOperands() >= 2) + min = binder.op->getOperand(1); + if (binder.op->getNumOperands() == 3) + max = binder.op->getOperand(2); + + // Note: attribute versions of the op only support float types. + auto resultDtype = resultType.getDtype(); + if (!min && binder.op->hasAttr("torch.onnx.min")) { + float minValue; + if (binder.f32FloatAttr(minValue, "min", + std::numeric_limits::lowest())) return failure(); - Value cstNone = - rewriter.create(binder.getLoc()); - rewriter.replaceOpWithNewOp( - binder.op, resultType, source, /*min=*/cstNone, /*max=*/cstNone); - return success(); - } else if (binder.op->getNumOperands() == 2) { - Value source, min; - if (binder.tensorOperands(source, min) || - binder.tensorResultType(resultType)) + auto minSplatAttr = SplatElementsAttr::get( + resultType.toBuiltinTensor().clone(resultDtype), + rewriter.getFloatAttr(resultDtype, minValue)); + min = rewriter.create( + binder.getLoc(), resultType, minSplatAttr); + } + if (!max && binder.op->hasAttr("torch.onnx.max")) { + float maxValue; + if (binder.f32FloatAttr(maxValue, "max", + std::numeric_limits::max())) return failure(); + auto maxSplatAttr = SplatElementsAttr::get( + resultType.toBuiltinTensor().clone(resultDtype), + rewriter.getFloatAttr(resultDtype, maxValue)); + max = rewriter.create( + binder.getLoc(), resultType, maxSplatAttr); + } + + if (!min && !max) { + // Cliping with no limits is a no-op. + rewriter.replaceOp(binder.op, source); + return success(); + } + + if (!max) { rewriter.replaceOpWithNewOp( - binder.op, resultType, source, /*min=*/min); + binder.op, resultType, source, min); return success(); - } else if (binder.op->getNumOperands() == 3) { - Value source, min, max; - if (binder.tensorOperandAtIndex(source, 0) || - binder.tensorOperandAtIndex(min, 1) || - binder.tensorOperandAtIndex(max, 2) || - binder.tensorResultType(resultType)) + } + + rewriter.replaceOpWithNewOp( + binder.op, resultType, source, min, max); + return success(); + }); + patterns.onOp( + "Concat", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + SmallVector tensors; + int64_t dim; + if (binder.tensorOperands(tensors, binder.op->getNumOperands()) || + binder.s64IntegerAttr(dim, "axis", 0) || + binder.tensorResultType(resultType)) + return failure(); + Type listElemType = + tensors[0] + .getType() + .cast() + .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, + /*optionalDtype=*/nullptr); + Type listType = Torch::ListType::get(listElemType); + Value tensorList = rewriter.create( + binder.op->getLoc(), listType, tensors); + Value cstDim = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(dim)); + rewriter.replaceOpWithNewOp(binder.op, resultType, + tensorList, cstDim); + return success(); + }); + patterns.onOp( + "Constant", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + if (binder.tensorResultType(resultType)) + return failure(); + auto dtype = resultType.getDtype(); + + float floatValue; + if (binder.op->hasAttr("torch.onnx.value_float") && + !binder.f32FloatAttr(floatValue, "value_float", 0.0)) { + auto splatAttr = + SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype), + rewriter.getFloatAttr(dtype, floatValue)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, splatAttr); + return success(); + } + + int64_t intValue; + if (binder.op->hasAttr("torch.onnx.value_int") && + !binder.s64IntegerAttr(intValue, "value_int", 0)) { + auto splatAttr = + SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype), + rewriter.getIntegerAttr(dtype, intValue)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, splatAttr); + return success(); + } + + if (DenseResourceElementsAttr attr = + binder.op->getAttr("torch.onnx.value") + .dyn_cast_or_null()) { + // Bytes are stored in little endian order. Big endian support will + // require swizzling. + if (!Endian::little) { + binder.op->emitError( + "unimplemented: importing on big endian systems"); return failure(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, source, min, max); + } + + auto ty = cast(attr.getType()); + ElementsAttr denseAttr; + auto ptr = attr.getRawHandle().getBlob()->getData(); + if (cast(attr.getType()).getElementType().isInteger(1)) { + llvm::SmallVector newContents; + for (auto val : ptr) { + APInt apval(1, val); + newContents.push_back(apval); + } + denseAttr = DenseElementsAttr::get(ty, newContents); + } else { + denseAttr = DenseElementsAttr::getFromRawBuffer(ty, ptr); + } + + rewriter.replaceOpWithNewOp( + binder.op, resultType, denseAttr); + return success(); + } + + if (ElementsAttr attr = binder.op->getAttr("torch.onnx.value") + .dyn_cast_or_null()) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, attr); return success(); } + + llvm::SmallVector intValues; + if (!binder.s64IntegerArrayAttr(intValues, "value_ints", {}) && + !intValues.empty()) { + llvm::SmallVector apValues; + for (auto intVal : intValues) { + apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), intVal)); + } + auto attr = DenseElementsAttr::get( + resultType.toBuiltinTensor().clone(dtype), apValues); + rewriter.replaceOpWithNewOp( + binder.op, resultType, attr); + return success(); + } + return failure(); }); + patterns.onOp( + "Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + std::string autoPad; + if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) + return failure(); + if (autoPad != "NOTSET") { + // TODO: Add support for `auto_pad` != "NOTSET" + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: auto_pad != NOTSET"); + } + + Torch::ValueTensorType resultType; + Value input, weight; + int64_t group; + if (binder.tensorOperandAtIndex(input, 0) || + binder.tensorOperandAtIndex(weight, 1) || + binder.s64IntegerAttr(group, "group", 1) || + binder.tensorResultType(resultType)) + return failure(); + + auto weightTensorType = weight.getType().cast(); + if (!weightTensorType || !weightTensorType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected weight type having sizes"); + } + ArrayRef weightShape = weightTensorType.getSizes(); + SmallVector kernelShape; + if (binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {})) + return failure(); + if (kernelShape.size()) { + if (kernelShape.size() != weightShape.size() - 2) { + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: kernel_shape list size should have " + "number of values equal to weight_rank - 2"); + } else { + for (unsigned i = 0; i < kernelShape.size(); i++) { + if (weightShape[i + 2] != kernelShape[i]) { + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: kernel_shape value " + "should be equal to the weight tensor shape"); + } + } + } + } + + // Determine the rank of input tensor. + std::optional maybeRank = Torch::getTensorRank(input); + if (!maybeRank) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: unranked tensor"); + unsigned rank = *maybeRank; + + SmallVector padding, strides, dilations; + SmallVector defaultPadding, defaultStrides, defaultDilations; + for (unsigned i = 0; i < rank - 2; i++) { + defaultPadding.push_back(0); + defaultStrides.push_back(1); + defaultDilations.push_back(1); + } + // Padding for the beginning and ending along each spatial axis, it can + // take any value greater than or equal to 0. The value represent the + // number of pixels added to the beginning and end part of the + // corresponding axis. pads format should be as follow [x1_begin, + // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added + // at the beginning of axis i and xi_end, the number of pixels added at + // the end of axis i. + if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) { + return failure(); + } + if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) { + return rewriter.notifyMatchFailure( + binder.op, "padding list size does not match the number of axes"); + } + if (binder.s64IntegerArrayAttr(dilations, "dilations", + defaultDilations)) { + return failure(); + } + if (dilations.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, + "dilations list size does not match the number of axes"); + } + if (binder.s64IntegerArrayAttr(strides, "strides", defaultStrides)) { + return failure(); + } + if (strides.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, "strides list size does not match the number of axes"); + } + + SmallVector cstPadding, cstStrides, cstDilations, + cstOutputPadding; + if (padding.size() != 2 * (rank - 2)) { + for (int64_t i : padding) { + cstPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + } else { + for (unsigned i = 0; i < padding.size() / 2; i++) { + if (padding[i] != padding[i + (padding.size() / 2)]) { + // TODO: Add support for different padding values for the + // beginning and ending along each spatial axis + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: padding values for the beginning " + "and ending along each spatial axis must be equal"); + } + cstPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); + } + } + for (int64_t i : dilations) { + cstDilations.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + for (int64_t i : strides) { + cstStrides.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + Value cstZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + cstOutputPadding = {cstZero, cstZero}; + + Value paddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstPadding); + Value dilationsList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstDilations); + Value stridesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstStrides); + Value outputPaddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstOutputPadding); + Value transposed = + rewriter.create(binder.getLoc(), false); + Value bias; + if (binder.op->getNumOperands() == 3) { + if (binder.tensorOperandAtIndex(bias, 2)) { + return failure(); + } + } else { + bias = rewriter.create(binder.getLoc()); + } + Value cstGroup = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(group)); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, weight, bias, stridesList, + paddingList, dilationsList, transposed, outputPaddingList, + cstGroup); + return success(); + }); + patterns.onOp( + "ConvTranspose", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + std::string autoPad; + if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) + return failure(); + if (autoPad != "NOTSET") { + // TODO: Add support for `auto_pad` != "NOTSET" + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: auto_pad != NOTSET"); + } + SmallVector outputShape; + if (binder.s64IntegerArrayAttr(outputShape, "output_shape", {})) + return failure(); + if (outputShape.size()) { + // TODO: Add support for non-None output_shape value. + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: output_shape should be absent"); + } + Torch::ValueTensorType resultType; + Value input, weight; + int64_t group; + if (binder.tensorOperandAtIndex(input, 0) || + binder.tensorOperandAtIndex(weight, 1) || + binder.s64IntegerAttr(group, "group", 1) || + binder.tensorResultType(resultType)) + return failure(); + + auto weightTensorType = weight.getType().cast(); + if (!weightTensorType || !weightTensorType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected weight type having sizes"); + } + ArrayRef weightShape = weightTensorType.getSizes(); + SmallVector kernelShape; + if (binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {})) + return failure(); + if (kernelShape.size()) { + if (kernelShape.size() != weightShape.size() - 2) { + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: kernel_shape list size should have " + "number of values equal to weight_rank - 2"); + } else { + for (unsigned i = 0; i < kernelShape.size(); i++) { + if (weightShape[i + 2] != kernelShape[i]) { + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: kernel_shape value " + "should be equal to the weight tensor shape"); + } + } + } + } + + // Determine the rank of input tensor. + std::optional maybeRank = Torch::getTensorRank(input); + if (!maybeRank) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: unranked tensor"); + unsigned rank = *maybeRank; + + SmallVector padding, strides, dilations, outputPadding; + SmallVector defaultPadding, defaultStrides, defaultDilations, + defaultOutputPadding; + for (unsigned i = 0; i < rank - 2; i++) { + defaultPadding.push_back(0); + defaultStrides.push_back(1); + defaultDilations.push_back(1); + defaultOutputPadding.push_back(0); + } + // Padding for the beginning and ending along each spatial axis, it can + // take any value greater than or equal to 0. The value represent the + // number of pixels added to the beginning and end part of the + // corresponding axis. pads format should be as follow [x1_begin, + // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added + // at the beginning of axis i and xi_end, the number of pixels added at + // the end of axis i. + if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) { + return failure(); + } + if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) { + return rewriter.notifyMatchFailure( + binder.op, "padding list size does not match the number of axes"); + } + if (binder.s64IntegerArrayAttr(dilations, "dilations", + defaultDilations)) { + return failure(); + } + if (dilations.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, + "dilations list size does not match the number of axes"); + } + if (binder.s64IntegerArrayAttr(strides, "strides", defaultStrides)) { + return failure(); + } + if (strides.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, "strides list size does not match the number of axes"); + } + if (binder.s64IntegerArrayAttr(outputPadding, "output_padding", + defaultOutputPadding)) { + return failure(); + } + if (outputPadding.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, + "output_padding list size does not match the number of axes"); + } + + SmallVector cstPadding, cstStrides, cstDilations, + cstOutputPadding; + if (padding.size() != 2 * (rank - 2)) { + for (int64_t i : padding) { + cstPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + } else { + for (unsigned i = 0; i < padding.size() / 2; i++) { + if (padding[i] != padding[i + (padding.size() / 2)]) { + // TODO: Add support for different padding values for the + // beginning and ending along each spatial axis + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: padding values for the beginning " + "and ending along each spatial axis must be equal"); + } + cstPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); + } + } + for (int64_t i : dilations) { + cstDilations.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + for (int64_t i : strides) { + cstStrides.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + for (int64_t i : outputPadding) { + cstOutputPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + + Value paddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstPadding); + Value dilationsList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstDilations); + Value stridesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstStrides); + Value outputPaddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstOutputPadding); + Value transposed = + rewriter.create(binder.getLoc(), true); + Value bias; + if (binder.op->getNumOperands() == 3) { + if (binder.tensorOperandAtIndex(bias, 2)) { + return failure(); + } + } else { + bias = rewriter.create(binder.getLoc()); + } + Value cstGroup = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(group)); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, weight, bias, stridesList, + paddingList, dilationsList, transposed, outputPaddingList, + cstGroup); + return success(); + }); patterns.onOp("Cos", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -327,11 +1160,236 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); + patterns.onOp("Cosh", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp( + "CumSum", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Location loc = binder.getLoc(); + Torch::ValueTensorType resultType; + Value operand; + Value axisTensor; + if (binder.tensorOperands(operand, axisTensor) || + binder.tensorResultType(resultType)) + return failure(); + + int64_t exclusive; + int64_t reverse; + // if bind succeeds and either is set, fail because not implemented + if (!binder.s64IntegerAttr(exclusive, "exclusive", 0)) + if (exclusive != 0) + return rewriter.notifyMatchFailure( + binder.op, "unsupported onnx.CumSum conversion: exclusive"); + if (!binder.s64IntegerAttr(reverse, "reverse", 0)) + if (reverse != 0) + return rewriter.notifyMatchFailure( + binder.op, "unsupported onnx.CumSum conversion: reverse"); + + // deal with neg axis: if (axis < 0) axis += rank + int64_t rank = + cast(operand.getType()).getSizes().size(); + Value rankVal = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), rank)); + Value zero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + + Value axisScalar = rewriter.create( + binder.getLoc(), rewriter.getType(), axisTensor); + Value isNegative = rewriter.create( + binder.getLoc(), axisScalar, zero); + isNegative = + rewriter.create(binder.getLoc(), isNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isNegative, rankVal); + Value dim = rewriter.create( + binder.getLoc(), axisScalar, finalOffset); + + Torch::BaseTensorType resultTensorType = + resultType.cast(); + if (!resultTensorType.hasDtype()) { + return rewriter.notifyMatchFailure( + binder.op, "expected result type to have a dtype"); + } + // resultTensorType.print(llvm::outs()); + Value none = rewriter.create(loc); + rewriter.replaceOpWithNewOp(binder.op, resultType, + operand, dim, none); + return success(); + }); + patterns.onOp( + "DepthToSpace", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input; + int64_t blockSize; + std::string mode; + if (binder.tensorOperand(input) || + binder.s64IntegerAttr(blockSize, "blocksize") || + binder.customOpNameStringAttr(mode, "mode", "DCR") || + binder.tensorResultType(resultType)) + return failure(); + auto inputTy = input.getType().dyn_cast(); + if (!inputTy || !inputTy.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected input type having sizes"); + } + SmallVector inputSizes{inputTy.getSizes()}; + if (inputSizes.size() != 4) { + return rewriter.notifyMatchFailure(binder.op, + "Expected input rank to be 4"); + } + Value b = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0))); + Value c = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1))); + Value h = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(2))); + Value w = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(3))); + Value cstBlockSize = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(blockSize)); + Value cstBlockSizeSquare = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(blockSize * blockSize)); + Value cDivBlockSizeSquare = rewriter.create( + binder.getLoc(), c, cstBlockSizeSquare); + cDivBlockSizeSquare = rewriter.create( + binder.getLoc(), cDivBlockSizeSquare); + Value reshapeSizesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(input.getContext())), + llvm::SmallVector{b, cstBlockSize, cstBlockSize, + cDivBlockSizeSquare, h, w}); + int64_t cDivBlockSizeSquareInt = + inputSizes[1] == Torch::kUnknownSize + ? Torch::kUnknownSize + : inputSizes[1] / (blockSize * blockSize); + SmallVector reshapeSizesInt{ + inputSizes[0], blockSize, blockSize, + cDivBlockSizeSquareInt, inputSizes[2], inputSizes[3]}; + Value reshapedInput = rewriter.create( + binder.getLoc(), + inputTy.getWithSizesAndDtype(reshapeSizesInt, + inputTy.getOptionalDtype()), + input, reshapeSizesList); + + Value transposedInput; + if (mode == "DCR") { + if (failed(createTorchTransposeOp( + rewriter, binder.getLoc(), reshapedInput, + /*dimA=*/1, /*dimB=*/3, transposedInput))) + return rewriter.notifyMatchFailure( + binder.op, "Failed to create TorchTranspose op"); + if (failed(createTorchTransposeOp( + rewriter, binder.getLoc(), transposedInput, + /*dimA=*/2, /*dimB=*/4, transposedInput))) + return rewriter.notifyMatchFailure( + binder.op, "Failed to create TorchTranspose op"); + } else { + // mode == "CRD" + if (failed(createTorchTransposeOp( + rewriter, binder.getLoc(), reshapedInput, + /*dimA=*/2, /*dimB=*/4, transposedInput))) + return rewriter.notifyMatchFailure( + binder.op, "Failed to create TorchTranspose op"); + if (failed(createTorchTransposeOp( + rewriter, binder.getLoc(), transposedInput, + /*dimA=*/3, /*dimB=*/4, transposedInput))) + return rewriter.notifyMatchFailure( + binder.op, "Failed to create TorchTranspose op"); + } + if (failed(createTorchTransposeOp( + rewriter, binder.getLoc(), transposedInput, + /*dimA=*/4, /*dimB=*/5, transposedInput))) + return rewriter.notifyMatchFailure( + binder.op, "Failed to create TorchTranspose op"); + + Value hMulBlockSize = rewriter.create( + binder.getLoc(), h, cstBlockSize); + Value wMulBlockSize = rewriter.create( + binder.getLoc(), w, cstBlockSize); + reshapeSizesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(input.getContext())), + llvm::SmallVector{b, cDivBlockSizeSquare, hMulBlockSize, + wMulBlockSize}); + rewriter.replaceOpWithNewOp( + binder.op, resultType, transposedInput, reshapeSizesList); + return success(); + }); + patterns.onOp( + "DequantizeLinear", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if (binder.tensorOperands(operands, 3) || + binder.tensorResultType(resultType)) + return failure(); + + Value operand = operands[0]; + Value scale = operands[1]; + Value zeropoint = operands[2]; + + auto operandTy = operand.getType().cast(); + + auto scaleTy = scale.getType().dyn_cast(); + if (!scaleTy || !scaleTy.hasSizes()) + return rewriter.notifyMatchFailure(binder.op, "requires known rank"); + if (!resultType.hasDtype()) + return rewriter.notifyMatchFailure(binder.op, + "requires known resulty dtype"); + + if (scaleTy.getSizes().size() == 0) { + Type qTy = operandTy.getDtype(); + + if (qTy.isUnsignedInteger(8)) { + qTy = rewriter.getType(); + } else if (qTy.isSignedInteger(8)) { + qTy = rewriter.getType(); + } else if (qTy.isSignedInteger(32)) { + qTy = rewriter.getType(); + } else { + return rewriter.notifyMatchFailure(binder.op, + "unsupported result dtype"); + } + + auto qTensorTy = rewriter.getType( + resultType.getOptionalSizes(), qTy); + scale = rewriter.create( + binder.getLoc(), rewriter.getType(), scale); + zeropoint = rewriter.create( + binder.getLoc(), rewriter.getType(), zeropoint); + + auto quantize = + rewriter.create( + binder.getLoc(), qTensorTy, operand, scale, zeropoint); + rewriter.replaceOpWithNewOp( + binder.op, resultType, quantize); + return success(); + } + + return failure(); + }); patterns.onOp("Div", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; - std::string direction; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) return failure(); @@ -339,7 +1397,89 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, lhs, rhs); return success(); }); - patterns.onOp("Equal", 19, + patterns.onOp( + "Dropout", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Location loc = binder.getLoc(); + Torch::ValueTensorType resultType; + int64_t numOperands = binder.op->getNumOperands(); + SmallVector operands; + int64_t seed; + if (binder.tensorOperands(operands, numOperands) || + binder.s64IntegerAttr(seed, "seed", 0) || + binder.tensorResultTypeAtIndex(resultType, 0)) + return failure(); + + // Global Seed value is 0. + if (seed != 0) { + return rewriter.notifyMatchFailure(binder.op, + "expected seed value to be 0"); + } + + Value ratio, trainingMode; + if (numOperands == 3) { + ratio = rewriter.create(loc, operands[1]); + Value trainVal = operands[2]; + auto trainTensorType = + trainVal.getType().dyn_cast(); + if (!trainTensorType) + return rewriter.notifyMatchFailure(binder.op, + "train tensor must have a type"); + + Type inputDtype = trainTensorType.getOptionalDtype(); + if (!inputDtype || !inputDtype.isInteger(1)) + return rewriter.notifyMatchFailure( + binder.op, + "train tensor must have an integer dtype of width 1"); + + std::optional inputRank = Torch::getTensorRank(trainVal); + if (!inputRank || *inputRank != 0) + return rewriter.notifyMatchFailure(binder.op, + "train tensor must have rank 0"); + + if (auto valueTensorLiteralOp = + trainVal.getDefiningOp()) { + auto val = valueTensorLiteralOp.getValue() + .cast() + .getSplatValue(); + trainingMode = rewriter.create(loc, val); + } else { + Value trainingModeScalar = + rewriter.create(loc, operands[2]); + Value cstOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + trainingMode = rewriter.create( + loc, trainingModeScalar, cstOne); + } + } else if (numOperands == 2) { + ratio = rewriter.create(loc, operands[1]); + trainingMode = rewriter.create(loc, false); + } else { + ratio = rewriter.create( + loc, rewriter.getF64FloatAttr(0.5)); + trainingMode = rewriter.create(loc, false); + } + + Value dropout = rewriter.create( + loc, resultType, /*input=*/operands[0], ratio, trainingMode); + + if (binder.op->getNumResults() == 1) { + rewriter.replaceOp(binder.op, dropout); + return success(); + } + Torch::ValueTensorType maskType; + if (binder.tensorResultTypeAtIndex(maskType, 1)) + return failure(); + Value dtype = rewriter.create( + loc, rewriter.getI64IntegerAttr( + (int64_t)torch_upstream::ScalarType::Bool)); + Value none = rewriter.create(loc); + Value mask = rewriter.create( + loc, maskType, operands[0], dtype, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none); + rewriter.replaceOp(binder.op, {dropout, mask}); + return success(); + }); + patterns.onOp("Equal", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; @@ -351,6 +1491,199 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, lhs, rhs); return success(); }); + patterns.onOp("Elu", 6, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Location loc = binder.getLoc(); + Torch::ValueTensorType resultType; + Value input; + float alpha; + if (binder.tensorOperand(input) || + binder.f32FloatAttr(alpha, "alpha") || + binder.tensorResultType(resultType)) + return failure(); + Value cstAlpha = rewriter.create( + loc, rewriter.getF64FloatAttr(alpha)); + Value cstOne = rewriter.create( + loc, rewriter.getF64FloatAttr(1.0)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, cstAlpha, /*scale=*/cstOne, + /*input_scale=*/cstOne); + return success(); + }); + patterns.onOp("Erf", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + std::string direction; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp("Exp", 6, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp( + "Expand", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // uses ideas and code from onnx.Reshape + auto loc = binder.getLoc(); + Torch::ValueTensorType resultType; + Value data, shape; + if (binder.tensorOperands(data, shape) || + binder.tensorResultType(resultType)) + return failure(); + + auto dataType = cast(data.getType()); + auto shapeType = cast(shape.getType()); + if (!dataType.hasSizes() || !shapeType.hasSizes()) + return failure(); + + auto shapeSizes = shapeType.getSizes(); + int64_t dataRank = dataType.getSizes().size(); + int64_t shapeRank = shapeSizes.size(); + if (shapeRank != 1 || shapeSizes[0] == Torch::kUnknownSize) + return failure(); + + auto rankDifference = dataRank - shapeSizes[0]; + + SmallVector selectSizes; + Type selectResultType = shapeType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype()); + // Variable to store 1-D onnx shape tensor, shapeSizes[0] has the + // dimension size + // A constant zero value + Value zero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + // Variable to store pytorch int list of shape (dimension) + SmallVector dimList; + + // Convert the shape tensor from vector of int64_t to torch int list as + // we are using torch implementation Torch::AtenBroadcastToOp which + // takes list of int + for (int i = 0; i < shapeSizes[0]; i++) { + Value selectIndex = rewriter.create( + loc, rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + loc, selectResultType, shape, zero, selectIndex); + Value dim = rewriter.create( + loc, rewriter.getType(), extract); + + if (i + rankDifference >= 0) { + Value iv = + rewriter.create(loc, i + rankDifference); + auto sz = rewriter.create( + loc, rewriter.getType(), data, iv); + dim = rewriter.create(loc, dim, sz); + } + + dimList.push_back(dim); + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + dimList); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, dimValueList); + return success(); + }); + patterns.onOp( + "Flatten", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // Flatten means to partition the input tensor's dimensions + // into a "left range" spanning 0 to axis - 1 and a "right range" + // spanning axis to rank - 1. Each range is then collapsed + // into a single dimension, resulting in a 2-D tensor. + // If either range is empty, it is replaced with a single + // dimension of size 1. + // + // For example, for a 4-D input tensor of shape (a, b, c, d) + // and axis==2, flatten produces a 2-D tensor of shape + // (a*b, c*d). + // + // If instead axis==0, the left range is empty, and the result + // is (1, a*b*c*d). + + Torch::ValueTensorType resultType; + Value operand; + int64_t axis; + if (binder.tensorOperand(operand) || + binder.s64IntegerAttr(axis, "axis", 1) || + binder.tensorResultType(resultType)) + return failure(); + + auto operandTy = cast(operand.getType()); + llvm::SmallVector shape(operandTy.getSizes()); + int64_t rank = shape.size(); + + // If axis is negative, count from the right instead of left + if (axis < 0) + axis = rank + axis; + + // We collapse in the dimensions to the right of the axis. + for (int i = axis + 1; i < rank; ++i) { + bool dynamic = shape[axis] == Torch::kUnknownSize || + shape[i] == Torch::kUnknownSize; + if (dynamic) { + shape[axis] = Torch::kUnknownSize; + } else { + shape[axis] = shape[axis] * shape[i]; + } + } + + shape.resize(axis + 1, 1); + + auto baseType = rewriter.getType( + shape, operandTy.getDtype()); + Value collapsedRight; + if (axis >= rank) { + // If the right range is empty, add a dim of size 1 to the + // right side of the shape: + // cr = torch.unsqueeze(x, x.ndim) + Value rankConst = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(rank)); + collapsedRight = rewriter.create( + binder.getLoc(), baseType, operand, rankConst); + } else { + // Otherwise, collapse the right range into a single dimension: + // cr = torch._prims.collapse(x, axis, x.ndim - 1) + Value axisConst = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(axis)); + Value rankLess1Const = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(rank - 1)); + collapsedRight = rewriter.create( + binder.getLoc(), baseType, operand, axisConst, rankLess1Const); + } + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + + if (axis <= 0) { + // If the left range is empty, add a dim of size 1 to the + // left side of the shape: + // torch.unsqueeze(cr, 0) + rewriter.replaceOpWithNewOp( + binder.op, resultType, collapsedRight, zero); + return success(); + } + + // Otherwise, collapse the left range into a single dimension: + // torch._prims.collapse(cr, 0, axis - 1) + Value axisLess1Const = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(axis - 1)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, collapsedRight, zero, axisLess1Const); + return success(); + }); patterns.onOp("Floor", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -362,4 +1695,108 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); + patterns.onOp( + "ConstantOfShape", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value shape; + if (binder.tensorOperand(shape) || binder.tensorResultType(resultType)) + return failure(); + + // convert shape tensor to list of ints + auto shapeSizes = + dyn_cast(shape.getType()).getSizes(); + SmallVector dimList; + Torch::BaseTensorType shapeType = + shape.getType().cast(); + Type selectResultType = rewriter.getType( + ArrayRef({}), shapeType.getOptionalDtype()); + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + + for (int i = 0; i < shapeSizes[0]; i++) { + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, shape, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + dimList.push_back(dim); + } + + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + dimList); + Value noneVal = rewriter.create(binder.getLoc()); + + // Get fill_value if it is present. + // Assumption : resultDType and value attr type match. + auto attr = binder.op->getAttr("torch.onnx.value"); + auto resultDType = resultType.getDtype(); + + // Extract the fill value and dtype + // ONNX requires value attr to be a tensor + if (!attr) { + attr = DenseElementsAttr::get( + resultType.toBuiltinTensor().clone(resultDType), + rewriter.getFloatAttr(resultDType, 0.0)); + } + + // If its a dense resource attr we need to convert to a dense type: + if (DenseResourceElementsAttr rattr = + attr.dyn_cast_or_null()) { + // Bytes are stored in little endian order. Big endian support will + // require swizzling. + if (!Endian::little) { + binder.op->emitError( + "unimplemented: importing on big endian systems"); + return failure(); + } + + auto ty = cast(rattr.getType()); + auto ptr = rattr.getRawHandle().getBlob()->getData(); + auto denseAttr = DenseElementsAttr::getFromRawBuffer(ty, ptr); + attr = dyn_cast_or_null(denseAttr); + } + + Attribute splattr; + if (isa(attr)) { + auto denseAttr = attr.cast(); + splattr = denseAttr.getSplatValue(); + } + + if (!isa(splattr)) { + return rewriter.notifyMatchFailure( + binder.op, + "`value` attr tensor only supports types int and float for now."); + } + + Value splatvalue; + if (auto intattr = dyn_cast(splattr)) { + IntegerType intty = cast(intattr.getType()); + int64_t value; + if (intty.isUnsignedInteger()) { + value = intattr.getUInt(); + } else if (intty.isSignedInteger()) { + value = intattr.getSInt(); + } else { + value = intattr.getInt(); + } + splatvalue = + rewriter.create(binder.getLoc(), value); + } + + if (auto fpattr = dyn_cast(splattr)) + splatvalue = rewriter.create( + binder.getLoc(), + rewriter.getF64FloatAttr(fpattr.getValueAsDouble())); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, dimValueList, splatvalue, /*dtype=*/noneVal, + /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal); + return success(); + }); } diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index af4f06fdef77..a7bdddbc8d78 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -8,6 +8,8 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" +#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; @@ -26,4 +28,1129 @@ using namespace mlir::torch::onnx_c; // results in a lot of ONNX test cases that all reduce to the exact same // thing here, so we simplify. void mlir::torch::onnx_c::populateDefaultDomainGtoP( - OnnxCustomOpConversionPattern &patterns) {} + OnnxCustomOpConversionPattern &patterns) { + patterns.onOp( + "HardSigmoid", 6, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value tensorOperand; + float alpha, beta; + if (binder.tensorOperand(tensorOperand) || + binder.f32FloatAttr(alpha, "alpha", 0.2f) || + binder.f32FloatAttr(beta, "beta", 0.5f) || + binder.tensorResultType(resultType)) + return failure(); + + // HardSigmoid computes the following expression: + // max(0, min(1, alpha * x + beta)) + Value constAlpha = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(alpha)); + + Value constBeta = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(beta)); + + // Expression: alpha * x + beta + Value alpha_x_plus_beta = rewriter.create( + binder.getLoc(), resultType, tensorOperand, constBeta, + /*alpha=*/constAlpha); + + // Expression: min(1, alpha * x + beta) + Value constantOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value oneTensor = createRank0Tensor(rewriter, binder.getLoc(), + resultType, constantOne); + Value minExpression = rewriter.create( + binder.getLoc(), resultType, oneTensor, alpha_x_plus_beta); + + // Expression: max(0, min(1, alpha * x + beta)) + Value constantZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value zeroTensor = createRank0Tensor(rewriter, binder.getLoc(), + resultType, constantZero); + rewriter.replaceOpWithNewOp( + binder.op, resultType, zeroTensor, minExpression); + return success(); + }); + patterns.onOp( + "Gelu", 20, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Value operand; + Torch::ValueTensorType resultType; + std::string approximate; + + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType) || + binder.customOpNameStringAttr(approximate, "approximate", "none")) + return failure(); + + Value vApproximate = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getStringAttr(approximate)); + + rewriter.replaceOpWithNewOp(binder.op, resultType, + operand, vApproximate); + return success(); + }); + patterns.onOp( + "GridSample", 17, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input; + Value grid; + if (binder.tensorOperandAtIndex(input, 0) || + binder.tensorOperandAtIndex(grid, 1) || + binder.tensorResultType(resultType)) + return rewriter.notifyMatchFailure( + binder.op, "operand grid_sampler bind failure"); + + auto inputTensorType = input.getType().cast(); + ArrayRef inputShape = inputTensorType.getSizes(); + uint32_t inputRank = inputShape.size(); + auto gridTensorType = grid.getType().cast(); + ArrayRef gridShape = gridTensorType.getSizes(); + uint32_t gridRank = gridShape.size(); + + if (inputRank != 4) + return rewriter.notifyMatchFailure(binder.op, + "only input rank 4 supported"); + if (gridRank != 4) + return rewriter.notifyMatchFailure(binder.op, + "only grid rank 4 supported"); + if (inputShape[0] != gridShape[0]) + return rewriter.notifyMatchFailure( + binder.op, "N must be same for input and grid"); + if (gridShape[3] != 2) + return rewriter.notifyMatchFailure(binder.op, + "gridShape[3] expected to be 2"); + std::string mode; + if (binder.customOpNameStringAttr(mode, "mode", "bilinear")) + return rewriter.notifyMatchFailure(binder.op, "mode bind failure"); + if (mode != "bilinear") + return rewriter.notifyMatchFailure( + binder.op, "currently only mode : bilinear supported"); + std::string padding; + if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros")) + return rewriter.notifyMatchFailure(binder.op, + "padding_mode bind failure"); + if (padding != "zeros") + return rewriter.notifyMatchFailure( + binder.op, "currently only padding_mode : zeros supported"); + int64_t align; + if (binder.s64IntegerAttr(align, "align_corners", 0)) + return rewriter.notifyMatchFailure(binder.op, + "align_corners bind failure"); + if (align != 1) + return rewriter.notifyMatchFailure( + binder.op, "currently only align_corners = 1 supported"); + + Value interpolationMode = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + Value paddingMode = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + Value alignCorners = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getBoolAttr(false)); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, grid, interpolationMode, paddingMode, + alignCorners); + return success(); + }); + patterns.onOp("Less", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp("LessOrEqual", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp("Log", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp("MatMul", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp( + "MatMulInteger", 10, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs, lhsZp, rhsZp; + if (binder.tensorOperandAtIndex(lhs, 0) || + binder.tensorOperandAtIndex(rhs, 1) || + binder.tensorResultType(resultType)) + return failure(); + + if (binder.tensorOperandAtIndex(lhsZp, 2)) { + lhsZp = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + } + + if (binder.tensorOperandAtIndex(rhsZp, 3)) { + rhsZp = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + } + + auto lhsTy = dyn_cast(lhs.getType()); + auto rhsTy = dyn_cast(rhs.getType()); + + if (auto zpTy = dyn_cast(lhsZp.getType())) { + for (auto dim : zpTy.getSizes()) + if (dim != 1) + return failure(); + lhsZp = rewriter.create( + binder.getLoc(), rewriter.getType(), lhsZp); + } + + if (auto zpTy = dyn_cast(rhsZp.getType())) { + for (auto dim : zpTy.getSizes()) + if (dim != 1) + return failure(); + rhsZp = rewriter.create( + binder.getLoc(), rewriter.getType(), rhsZp); + } + + Value scale = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(1.0)); + + auto q = [&](Type qty) -> Type { + if (qty.isSignedInteger(8)) + return rewriter.getType(); + if (qty.isUnsignedInteger(8)) + return rewriter.getType(); + if (qty.isSignedInteger(32)) + return rewriter.getType(); + return {}; + }; + + Type lhsQTy = rewriter.getType( + lhsTy.getOptionalSizes(), q(lhsTy.getDtype())); + Type rhsQTy = rewriter.getType( + rhsTy.getOptionalSizes(), q(rhsTy.getDtype())); + + lhs = rewriter.create( + binder.getLoc(), lhsQTy, lhs, scale, lhsZp); + rhs = rewriter.create( + binder.getLoc(), rhsQTy, rhs, scale, rhsZp); + + rewriter.replaceOpWithNewOp(binder.op, resultType, lhs, + rhs); + return success(); + }); + patterns.onOp("Mul", 7, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp("NonZero", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp( + "MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + std::string autoPad; + if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) + return rewriter.notifyMatchFailure(binder.op, + "auto_pad bind failure"); + if (autoPad != "NOTSET") + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: auto_pad != NOTSET"); + + Torch::ValueTensorType resultType; + Value operand; + bool ceilMode; + int64_t storageOrder; + // TODO: Add support for indices output and storage_order + if (binder.tensorOperand(operand) || + binder.s64BoolAttr(ceilMode, "ceil_mode", false) || + binder.s64IntegerAttr(storageOrder, "storage_order", 0) || + binder.tensorResultType(resultType)) + return rewriter.notifyMatchFailure( + binder.op, + "operand/ceil_mode/storage_order/resultType bind failure"); + if (storageOrder != 0) + return rewriter.notifyMatchFailure( + binder.op, "storage_order setting is not supported."); + // Determine the rank of input tensor. + std::optional maybeRank = Torch::getTensorRank(operand); + if (!maybeRank) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: unranked tensor"); + int64_t rank = *maybeRank; + int64_t spatial = rank - 2; + + SmallVector kernel, padding, strides, dilations; + if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {})) + return rewriter.notifyMatchFailure(binder.op, + "kernel_shape bind failure"); + if (kernel.size() != static_cast(spatial)) + return rewriter.notifyMatchFailure( + binder.op, "kernel list size does not match the number of axes"); + if (binder.s64IntegerArrayAttr(padding, "pads", {})) + return rewriter.notifyMatchFailure(binder.op, "pads bind failure"); + if (!padding.empty() && + padding.size() != static_cast(2 * spatial)) + return rewriter.notifyMatchFailure( + binder.op, "padding list must contain (begin,end) pair for each " + "spatial axis"); + if (binder.s64IntegerArrayAttr(strides, "strides", {})) + return rewriter.notifyMatchFailure(binder.op, "strides bind failure"); + if (!strides.empty() && strides.size() != static_cast(spatial)) + return rewriter.notifyMatchFailure( + binder.op, "strides list size does not match the number of axes"); + if (binder.s64IntegerArrayAttr(dilations, "dilations", {})) + return rewriter.notifyMatchFailure(binder.op, + "dilations bind failure"); + + if (padding.empty()) + padding.resize(spatial, 0); + if (strides.empty()) + strides.resize(spatial, 1); + if (dilations.empty()) + dilations.resize(spatial, 1); + + // If the padding is symmetric we can push the padding operation to the + // torch operator. + if (padding.size() == static_cast(2 * spatial)) { + bool equal = true; + for (int i = 0; i < spatial; ++i) { + equal = equal && (padding[i] == padding[i + spatial]); + } + if (equal) + padding.resize(spatial); + } + + // Torch pool operators require equal padding on each size of each + // dimension so we materialize the padding behavior explicitly and set + // the padding to 0. + if (padding.size() == static_cast(2 * spatial)) { + auto operandTy = cast(operand.getType()); + llvm::SmallVector shuffledPadding(spatial * 2); + llvm::SmallVector paddedShape(operandTy.getSizes()); + shuffledPadding.resize(2 * rank); + for (int i = 0; i < spatial; ++i) { + paddedShape[i + 2] += padding[i] + padding[i + spatial]; + shuffledPadding[2 * i] = padding[i]; + shuffledPadding[2 * i + 1] = padding[i + spatial]; + } + + Value shuffledPaddingList = + createConstantIntList(binder, rewriter, padding); + Value zero; + if (resultType.getDtype().isa()) { + zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr( + std::numeric_limits::lowest())); + } else if (resultType.getDtype().isa()) { + zero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr( + std::numeric_limits::lowest())); + } + + auto paddedInputTy = rewriter.getType( + paddedShape, operandTy.getDtype()); + operand = rewriter.create( + binder.getLoc(), paddedInputTy, operand, shuffledPaddingList, + zero); + padding.clear(); + padding.resize(spatial, 0); + } + + Value kernelSizeList = createConstantIntList(binder, rewriter, kernel); + Value paddingList = createConstantIntList(binder, rewriter, padding); + Value stridesList = createConstantIntList(binder, rewriter, strides); + Value dilationsList = + createConstantIntList(binder, rewriter, dilations); + Value cstCeilMode = + rewriter.create(binder.getLoc(), ceilMode); + + if (rank == 3) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: AtenMaxPool1dOp"); + if (rank == 4) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, dilationsList, cstCeilMode); + return success(); + } + if (rank == 5) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, dilationsList, cstCeilMode); + return success(); + } + return rewriter.notifyMatchFailure(binder.op, "No rank is matched."); + }); + patterns.onOp("Greater", 16, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + std::string direction; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp("GreaterOrEqual", 16, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + std::string direction; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp( + "InstanceNormalization", 6, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + float eps; + + if (binder.tensorOperands(operands, 3) || + binder.tensorResultType(resultType) || operands.size() != 3 || + binder.f32FloatAttr(eps, "epsilon", 1e-05f)) { + return failure(); + } + Value none = rewriter.create(binder.getLoc()); + Value boolTrue = + rewriter.create(binder.getLoc(), true); + Value boolFalse = + rewriter.create(binder.getLoc(), false); + auto epsValue = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(eps)); + + auto momentum = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(0.0f)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, /* input */ operands[0], + /* weight */ operands[1], + /* bias */ operands[2], /* running mean */ none, + /* running var */ none, + /* use input stats */ boolTrue, momentum, epsValue, + /* cudnn enabled */ boolFalse); + return success(); + }); + patterns.onOp( + "Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if (binder.tensorOperandsList(operands) || + binder.tensorResultType(resultType) || operands.size() == 0) { + return failure(); + } + Value result = operands[0]; + for (uint64_t i = 1; i < operands.size(); i++) { + result = rewriter.create( + binder.getLoc(), resultType, result, operands[i]); + } + rewriter.replaceOp(binder.op, result.getDefiningOp()); + return success(); + }); + patterns.onOp( + "Min", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if (binder.tensorOperandsList(operands) || + binder.tensorResultType(resultType) || operands.size() == 0) { + return failure(); + } + Value result = operands[0]; + for (uint64_t i = 1; i < operands.size(); i++) { + result = rewriter.create( + binder.getLoc(), resultType, result, operands[i]); + } + rewriter.replaceOp(binder.op, result.getDefiningOp()); + return success(); + }); + patterns.onOp("Neg", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp("Not", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp("Or", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp( + "Gather", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data, indices; + int64_t axis; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorOperandAtIndex(indices, 1) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(axis, "axis", 0)) + return failure(); + Location loc = binder.getLoc(); + auto ctx = binder.op->getContext(); + auto indicesTy = cast(indices.getType()); + auto dataTy = cast(data.getType()); + if (!dataTy || !dataTy.hasSizes() || !indicesTy.hasSizes()) + return failure(); + + int64_t dataRank = dataTy.getSizes().size(); + int64_t indicesRank = indicesTy.getSizes().size(); + axis = axis < 0 ? axis + dataRank : axis; + + Value index = rewriter.create( + loc, Torch::IntType::get(ctx), rewriter.getI64IntegerAttr(axis)); + + // Apply bounds checking on the input: + auto intTy = rewriter.getType(); + auto boolTy = rewriter.getType( + indicesTy.getSizes(), rewriter.getI1Type()); + Value zero = rewriter.create( + loc, intTy, rewriter.getI64IntegerAttr(0)); + Value one = rewriter.create( + loc, intTy, rewriter.getI64IntegerAttr(1)); + Value lt = + rewriter.create(loc, boolTy, indices, zero); + Value dim = + rewriter.create(loc, intTy, data, index); + Value add = rewriter.create(loc, indicesTy, + indices, dim, one); + indices = rewriter.create(loc, indicesTy, lt, + add, indices); + + auto intListTy = rewriter.getType( + rewriter.getType()); + + llvm::SmallVector indicesDims; + for (int i = 0, s = indicesTy.getSizes().size(); i < s; ++i) { + Value k = rewriter.create(binder.getLoc(), i); + indicesDims.push_back(rewriter.create( + binder.getLoc(), indices, k)); + } + + Value indicesSizeList = rewriter.create( + binder.getLoc(), intListTy, indicesDims); + + // Determine the collapsed dim size: + auto indicesCt = 1; + for (auto sz : indicesTy.getSizes()) { + if (sz == Torch::kUnknownSize) { + indicesCt = Torch::kUnknownSize; + break; + } + + indicesCt *= sz; + } + + auto flattenTy = rewriter.getType( + SmallVector{indicesCt}, indicesTy.getOptionalDtype()); + + if (indicesRank == 0) { + indices = rewriter.create( + binder.getLoc(), flattenTy, indices, zero); + } else if (indicesRank > 1) { + Value rank = rewriter.create(loc, intTy, indices); + Value end = rewriter.create(loc, rank, one); + indices = rewriter.create( + loc, flattenTy, indices, zero, end); + } + + llvm::SmallVector gatherShape(dataTy.getSizes()); + gatherShape[axis] = indicesCt; + auto gatherTy = rewriter.getType( + gatherShape, dataTy.getOptionalDtype()); + Value gather = rewriter.create( + loc, gatherTy, data, index, indices); + + if (indicesRank == 1) { + rewriter.replaceOp(binder.op, gather); + return success(); + } + + if (indicesRank > 1) { + gather = rewriter.replaceOpWithNewOp( + binder.op, resultType, gather, index, indicesSizeList); + return success(); + } + + rewriter.replaceOpWithNewOp(binder.op, resultType, + gather); + return success(); + }); + patterns.onOp( + "GatherElements", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data, indices; + int64_t axis; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorOperandAtIndex(indices, 1) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(axis, "axis", 0)) + return failure(); + Value constAxis = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); + Value sparseGrad = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getBoolAttr(false)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, constAxis, indices, sparseGrad); + return success(); + }); + patterns.onOp( + "Gemm", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value a, b, c; + float alpha, beta; + int64_t transA, transB; + if (binder.tensorOperandAtIndex(a, 0) || + binder.tensorOperandAtIndex(b, 1) || + binder.s64IntegerAttr(transA, "transA", 0) || + binder.s64IntegerAttr(transB, "transB", 0) || + binder.f32FloatAttr(alpha, "alpha", 1.0f) || + binder.f32FloatAttr(beta, "beta", 1.0f) || + binder.tensorResultType(resultType)) + return failure(); + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + Value one = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); + + auto transpose = [&](Value m) -> Value { + auto tty = m.getType().cast(); + auto shape = tty.getOptionalSizes(); + if (shape.has_value()) { + llvm::SmallVector newShape(shape.value()); + std::reverse(newShape.begin(), newShape.end()); + shape = std::move(newShape); + } + auto oty = Torch::ValueTensorType::get(tty.getContext(), shape, + tty.getOptionalDtype()); + return rewriter.create(binder.getLoc(), + oty, m, zero, one); + }; + + if (transA) { + a = transpose(a); + } + + if (transB) { + b = transpose(b); + } + + if (binder.getNumOperands() == 2) { + rewriter.replaceOpWithNewOp(binder.op, resultType, a, + b); + return success(); + } + + if (binder.tensorOperandAtIndex(c, 2)) + return rewriter.notifyMatchFailure(binder.op, + "Expected either 2 or 3 inputs"); + + Value mm = + rewriter.create(binder.getLoc(), resultType, a, b); + if (alpha == 1.0 && beta == 1.0) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, mm, c, one); + return success(); + } + + if (alpha != 1.0 && beta != 1.0) { + Value constAlpha = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(alpha)); + mm = rewriter.create( + binder.getLoc(), resultType, mm, constAlpha); + alpha = 1.0; + } + + if (alpha != 1.0) { + std::swap(alpha, beta); + std::swap(mm, c); + } + + Value constBeta = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(beta)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, mm, c, constBeta); + return success(); + }); + patterns.onOp( + "GlobalAveragePool", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + + auto inputTensorType = operand.getType().cast(); + if (!inputTensorType || !inputTensorType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected input type having sizes"); + } + ArrayRef inputShape = inputTensorType.getSizes(); + unsigned inputRank = inputShape.size(); + if (!resultType || !resultType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected result type having sizes"); + } + ArrayRef resultShape = resultType.getSizes(); + + SmallVector cstKernel, cstPadding, cstStrides; + Value cstZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + for (unsigned i = 2; i < inputRank; i++) { + int64_t kernelSize = inputShape[i] - resultShape[i] + 1; + cstKernel.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(kernelSize))); + cstPadding.push_back(cstZero); + cstStrides.push_back(cstOne); + } + Value kernelSizeList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstKernel); + Value paddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstPadding); + Value stridesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstStrides); + Value cstFalse = + rewriter.create(binder.getLoc(), false); + Value cstCeilMode = cstFalse; + Value cstCountIncludePad = cstFalse; + Value cstNone = rewriter.create(binder.getLoc()); + + if (inputRank == 3) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, cstCeilMode, cstCountIncludePad); + return success(); + } else if (inputRank == 4) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, cstCeilMode, cstCountIncludePad, + /*divisor_override=*/cstNone); + return success(); + } else if (inputRank == 5) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, cstCeilMode, cstCountIncludePad, + /*divisor_override=*/cstNone); + return success(); + } + return failure(); + }); + patterns.onOp( + "LayerNormalization", 17, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType yType, meanType, invStdDevType; + Value x, scale, b; + int64_t axis, stashType; + float epsilon; + if (binder.tensorOperandAtIndex(x, 0) || + binder.tensorOperandAtIndex(scale, 1) || + binder.tensorOperandAtIndex(b, 2) || + binder.tensorResultTypeAtIndex(yType, 0) || + binder.s64IntegerAttr(axis, "axis", -1) || + binder.f32FloatAttr(epsilon, "epsilon", 0.00001f) || + binder.s64IntegerAttr(stashType, "stash_type", 1)) + return failure(); + Value constEpsilon = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(epsilon)); + unsigned rank = 1; + if (std::optional maybeRank = Torch::getTensorRank(x)) + rank = *maybeRank; + SmallVector normalized; + axis = Torch::toPositiveDim(axis, rank); + auto xType = x.getType().cast(); + if (!xType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected input (X) to have sizes"); + } + ArrayRef xShape = xType.getSizes(); + for (int64_t n = axis; n < rank; n++) { + normalized.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(xShape[n]))); + } + Value normalized_shape = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + normalized); + + int64_t numResults = binder.op->getNumResults(); + if (numResults == 1) { + SmallVector reducedShape(rank, 1); + for (int64_t i = 0; i < axis; i++) + reducedShape[i] = xShape[i]; + auto reducedType = xType.getWithSizesAndDtype( + reducedShape, xType.getOptionalDtype()); + Value y = rewriter + .create( + binder.getLoc(), yType, /*meanType=*/reducedType, + /*invStdDevType=*/reducedType, x, normalized_shape, + scale, b, constEpsilon) + .getResult0(); + rewriter.replaceOp(binder.op, y); + return success(); + } + if (numResults == 3) { + if (binder.tensorResultTypeAtIndex(meanType, 1) || + binder.tensorResultTypeAtIndex(invStdDevType, 2)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, yType, meanType, invStdDevType, x, normalized_shape, + scale, b, constEpsilon); + return success(); + } + return rewriter.notifyMatchFailure( + binder.op, "Unimplemented: expected either 1 or 3 results"); + }); + patterns.onOp("LeakyRelu", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + float alpha; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType) || + binder.f32FloatAttr(alpha, "alpha", 0.01f)) + return failure(); + Value constAlpha = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(alpha)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, constAlpha); + return success(); + }); + patterns.onOp( + "Pad", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data, pads, axes; + std::string mode; + + // TODO: The `axes` parameter is not supported yet. + if (!binder.tensorOperandAtIndex(axes, 3)) { + return rewriter.notifyMatchFailure( + binder.op, "The axes parameter is not supported yet"); + } + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorOperandAtIndex(pads, 1) || + binder.tensorResultType(resultType) || + binder.customOpNameStringAttr(mode, "mode", "constant")) + return failure(); + Location loc = binder.getLoc(); + + // Get pads shape and rank. The pads tensor is expected to be 1-D + // tensor. + auto padsTensorType = pads.getType().cast(); + if (!padsTensorType || !padsTensorType.hasSizes()) { + return rewriter.notifyMatchFailure(binder.op, + "Expect non empty pad tensor"); + } + ArrayRef padsShape = padsTensorType.getSizes(); + int64_t padsRank = padsShape.size(); + if (padsRank != 1) + return rewriter.notifyMatchFailure(binder.op, + "expect 1-d pad tensor"); + + int64_t padsSize = padsShape[0]; + if (padsSize == Torch::kUnknownSize) + return rewriter.notifyMatchFailure(binder.op, + "pad length is unknown"); + + Value constantValue; + if (binder.getNumOperands() >= 3) { + if (!binder.tensorOperandAtIndex(constantValue, 2)) { + auto constTy = + dyn_cast(constantValue.getType()); + if (!constTy || !constTy.hasDtype()) + return rewriter.notifyMatchFailure( + binder.op, "constant ty is unsupport type"); + + Type scalarTy = rewriter.getType(); + if (isa(constTy.getDtype())) + scalarTy = rewriter.getType(); + constantValue = rewriter.create(loc, scalarTy, + constantValue); + } + } + + if (!constantValue) { + auto dataTensorType = data.getType().cast(); + if (dataTensorType.getDtype().isa()) + constantValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + if (dataTensorType.getDtype().isa()) + constantValue = rewriter.create( + loc, rewriter.getF64FloatAttr(0.0f)); + + if (!constantValue) + return rewriter.notifyMatchFailure( + binder.op, "expected integer or float data tensor"); + } + + // Extract all the values of 1-D pad tensor and create a list of all + // these values as torch.pad op expects pad list. + Value constZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + SmallVector padsTensorValue; + SmallVector emptyShape; + Type padsElemType = + Torch::ValueTensorType::get(padsTensorType.getContext(), emptyShape, + padsTensorType.getOptionalDtype()); + for (uint32_t i = 0; i < padsSize; ++i) { + Value index = rewriter.create( + loc, rewriter.getI64IntegerAttr(i)); + auto select = rewriter.create( + loc, padsElemType, pads, constZero, index); + Value selectInt = rewriter.create( + loc, rewriter.getType(), select); + padsTensorValue.push_back(selectInt); + } + + // The torch.pad op expects a different arrangement of padding pairs for + // each dimension as compared to the onnx.pad op. So, rearranging pad + // tensor to satisfy torch.pad op semantics. + SmallVector padsRearrange; + for (uint32_t i = 0; i < padsSize / 2; i++) { + padsRearrange.emplace_back(padsTensorValue[i]); + padsRearrange.emplace_back(padsTensorValue[(padsSize / 2) + i]); + } + + Value padsSizeList = + rewriter + .create( + loc, + Torch::ListType::get(rewriter.getType()), + padsRearrange) + .getResult(); + Value modeVal = rewriter.create( + loc, rewriter.getStringAttr(mode)); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, padsSizeList, modeVal, constantValue); + return success(); + }); + patterns.onOp("Pow", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp( + "Identity", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value tensor; + if (binder.tensorOperand(tensor) || + binder.tensorResultType(resultType)) { + return failure(); + } + Value noneVal = rewriter.create(binder.getLoc()); + rewriter.replaceOpWithNewOp( + binder.op, resultType, tensor, /*memory_format=*/noneVal); + return success(); + }); + patterns.onOp( + "Mean", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + if (binder.op->getNumOperands() == 1) { + Torch::ValueTensorType resultType; + Value x; + if (binder.tensorOperand(x) || binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOp(binder.op, x); + return success(); + } + Torch::ValueTensorType resultType; + SmallVector valList; + int64_t numOperands = binder.op->getNumOperands(); + Value numOperandsConstant = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), numOperands)); + if (binder.tensorOperands(valList, numOperands) || + binder.tensorResultType(resultType)) + return failure(); + Value constOne = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); + // Short circuit to binary add + Value curr = rewriter.create( + binder.getLoc(), resultType, valList[0], valList[1], constOne); + if (numOperands == 2) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, curr, numOperandsConstant); + return success(); + } + // When binder.op->getNumOperands() > 2 + auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation( + binder.op->getContext()); + for (int i = 2; i < numOperands; i++) { + if (i == numOperands - 1) { + curr = rewriter.create( + binder.getLoc(), resultType, curr, valList[i], constOne); + } else { + curr = rewriter.create( + binder.getLoc(), baseType, curr, valList[i], constOne); + } + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, curr, numOperandsConstant); + return success(); + }); + patterns.onOp( + "IsInf", 10, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value tensor; + int64_t neg; + int64_t pos; + if (binder.tensorOperand(tensor) || + binder.s64IntegerAttr(neg, "detect_negative", 1) || + binder.s64IntegerAttr(pos, "detect_positive", 1) || + binder.tensorResultType(resultType)) { + return failure(); + } + if (neg == 0) { + // replace all negative infs with 0 + tensor = rewriter.create( + binder.getLoc(), + dyn_cast(tensor.getType()), tensor); + } + if (pos == 0) { + // first use neg op to flip positive inf to negative inf. Then relu to + // replace all positive infs with 0. + Value flip = rewriter.create( + binder.getLoc(), + dyn_cast(tensor.getType()), tensor); + tensor = rewriter.create( + binder.getLoc(), dyn_cast(flip.getType()), + flip); + } + rewriter.replaceOpWithNewOp(binder.op, resultType, + tensor); + return success(); + }); + patterns.onOp("IsNaN", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value tensor; + if (binder.tensorOperand(tensor) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, tensor); + return success(); + }); + patterns.onOp("PRelu", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value tensor; + Value slope; + if (binder.tensorOperands(tensor, slope) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, tensor, slope); + return success(); + }); +} diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 23af89f329ab..b5e9162bc2bf 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -8,6 +8,10 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" +#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" using namespace mlir; using namespace mlir::torch; @@ -25,5 +29,2074 @@ using namespace mlir::torch::onnx_c; // to be more normal and a direct translation vs a special case. This // results in a lot of ONNX test cases that all reduce to the exact same // thing here, so we simplify. + +// utilities +// Templatized function to get an item op of a type +namespace { +template +Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter, + Value &ofItem) { + return rewriter.create(binder.getLoc(), + rewriter.getType(), ofItem); +} +} // namespace + void mlir::torch::onnx_c::populateDefaultDomainQtoZ( - OnnxCustomOpConversionPattern &patterns) {} + OnnxCustomOpConversionPattern &patterns) { + patterns.onOp( + "QuantizeLinear", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if (binder.tensorOperands(operands, 3) || + binder.tensorResultType(resultType)) + return failure(); + + Value operand = operands[0]; + Value scale = operands[1]; + Value zeropoint = operands[2]; + + auto scaleTy = scale.getType().dyn_cast(); + if (!scaleTy || !scaleTy.hasSizes()) + return rewriter.notifyMatchFailure(binder.op, "requires known rank"); + if (!resultType.hasDtype()) + return rewriter.notifyMatchFailure(binder.op, + "requires known result dtype"); + + if (scaleTy.getSizes().size() == 0) { + Type qTy = resultType.getDtype(); + + if (qTy.isUnsignedInteger(8)) { + qTy = rewriter.getType(); + } else if (qTy.isSignedInteger(8)) { + qTy = rewriter.getType(); + } else if (qTy.isSignedInteger(32)) { + qTy = rewriter.getType(); + } else { + return rewriter.notifyMatchFailure(binder.op, + "unsupported result dtype"); + } + + auto qTensorTy = rewriter.getType( + resultType.getOptionalSizes(), qTy); + auto torchqTy = Torch::getScalarTypeForType(qTy); + + Value tyConst = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + static_cast(torchqTy))); + + scale = rewriter.create( + binder.getLoc(), rewriter.getType(), scale); + zeropoint = rewriter.create( + binder.getLoc(), rewriter.getType(), zeropoint); + + auto quantize = rewriter.create( + binder.getLoc(), qTensorTy, operand, scale, zeropoint, tyConst); + rewriter.replaceOpWithNewOp( + binder.op, resultType, quantize); + return success(); + } + + return failure(); + }); + patterns.onOp( + "QLinearConv", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if ((binder.tensorOperands(operands, 8) && + binder.tensorOperands(operands, 9)) || + binder.tensorResultType(resultType)) + return failure(); + Value a = operands[0]; + Value aScale = operands[1]; + Value aZp = operands[2]; + Value b = operands[3]; + Value bScale = operands[4]; + Value bZp = operands[5]; + Value cScale = operands[6]; + Value cZp = operands[7]; + Value c = operands.size() == 9 ? operands[8] : nullptr; + + auto check = [](Value v) { + auto vTy = v.getType().cast(); + return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1; }); + }; + if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) || + !check(cScale) || !check(cScale)) + return rewriter.notifyMatchFailure( + binder.op, "not supported for non per-tensor quantization"); + + auto extract = [&rewriter, &binder](Value v) { + auto vTy = v.getType().cast(); + Type extractTy = rewriter.getType(); + if (isa(vTy.getDtype())) + extractTy = rewriter.getType(); + + return rewriter.create(binder.getLoc(), extractTy, + v); + }; + + aZp = extract(aZp); + bZp = extract(bZp); + cZp = extract(cZp); + aScale = extract(aScale); + bScale = extract(bScale); + cScale = extract(cScale); + + auto make = [&rewriter, &binder](Value v, Value scale, + Value zp) -> Value { + auto ty = v.getType().cast(); + auto newTy = getQTorchTypeFromTorchIntType(ty); + return rewriter.create( + binder.getLoc(), newTy, v, scale, zp); + }; + + a = make(a, aScale, aZp); + b = make(b, bScale, bZp); + + auto cTy = rewriter.getType( + resultType.getOptionalSizes(), + rewriter.getIntegerType(32, /*issigned=*/true)); + + // TODO(suderman): insert convolution operator. + llvm::SmallVector newOperands = {a, b}; + if (c) + newOperands.push_back(c); + + cTy = rewriter.getType( + resultType.getOptionalSizes(), + rewriter.getType()); + + llvm::SmallVector newAttributes; + newAttributes.push_back( + rewriter.getNamedAttr("name", rewriter.getStringAttr("onnx.Conv"))); + for (auto namedAttr : binder.op->getAttrDictionary()) { + if (namedAttr.getName().getValue().compare("name") == 0) + continue; + llvm::errs() << namedAttr.getName() << "\n"; + newAttributes.push_back(namedAttr); + } + + c = rewriter + .create(binder.getLoc(), cTy, newOperands, + newAttributes, + binder.op->getRegions().size()) + .getResult(0); + + Value outScale = rewriter.create( + binder.getLoc(), rewriter.getType(), aScale, + bScale); + Value outZp = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + c = rewriter.create( + binder.getLoc(), cTy, c, outScale, outZp); + cTy = rewriter.getType( + resultType.getOptionalSizes(), rewriter.getF32Type()); + + c = rewriter.create(binder.getLoc(), cTy, + c); + cTy = dyn_cast( + getQTorchTypeFromTorchIntType(resultType)); + Value dtyVal = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr( + rewriter.getIntegerType(64), + static_cast( + Torch::getScalarTypeForType(cTy.getDtype())))); + c = rewriter.create( + binder.getLoc(), cTy, c, cScale, cZp, dtyVal); + rewriter.replaceOpWithNewOp(binder.op, resultType, + c); + return success(); + }); + patterns.onOp( + "QLinearMatMul", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if (binder.tensorOperands(operands, 8) || + binder.tensorResultType(resultType)) + return failure(); + Value a = operands[0]; + Value aScale = operands[1]; + Value aZp = operands[2]; + Value b = operands[3]; + Value bScale = operands[4]; + Value bZp = operands[5]; + Value cScale = operands[6]; + Value cZp = operands[7]; + + auto check = [](Value v) { + auto vTy = v.getType().cast(); + for (auto dim : vTy.getSizes()) + if (dim != 1) + return false; + return true; + }; + if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) || + !check(cScale) || !check(cScale)) + return rewriter.notifyMatchFailure( + binder.op, "not supported for non per-tensor quantization"); + + Value emptyList = rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + ValueRange{}); + auto extract = [&rewriter, &binder, &emptyList](Value v) { + auto vTy = v.getType().cast(); + if (!vTy.getSizes().empty()) { + vTy = rewriter.getType( + ArrayRef({}), vTy.getOptionalDtype()); + v = rewriter.create(binder.getLoc(), vTy, v, + emptyList); + } + + Type extractTy = rewriter.getType(); + if (isa(vTy.getDtype())) + extractTy = rewriter.getType(); + + return rewriter.create(binder.getLoc(), extractTy, + v); + }; + + aZp = extract(aZp); + bZp = extract(bZp); + cZp = extract(cZp); + aScale = extract(aScale); + bScale = extract(bScale); + cScale = extract(cScale); + + auto make = [&rewriter, &binder](Value v, Value scale, + Value zp) -> Value { + auto ty = v.getType().cast(); + auto newTy = getQTorchTypeFromTorchIntType(ty); + return rewriter.create( + binder.getLoc(), newTy, v, scale, zp); + }; + + a = make(a, aScale, aZp); + b = make(b, bScale, bZp); + + auto cTy = rewriter.getType( + resultType.getOptionalSizes(), + rewriter.getIntegerType(32, /*issigned=*/true)); + + Value c; + if (cTy.getSizes().size() == 2) { + c = rewriter.create(binder.getLoc(), cTy, a, b); + } else { + c = rewriter.create(binder.getLoc(), cTy, a, b); + } + + cTy = rewriter.getType( + resultType.getOptionalSizes(), + rewriter.getType()); + + Value mmScale = rewriter.create( + binder.getLoc(), rewriter.getType(), aScale, + bScale); + Value mmZp = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + c = rewriter.create( + binder.getLoc(), cTy, c, mmScale, mmZp); + cTy = rewriter.getType( + resultType.getOptionalSizes(), rewriter.getF32Type()); + + c = rewriter.create(binder.getLoc(), cTy, + c); + cTy = dyn_cast( + getQTorchTypeFromTorchIntType(resultType)); + Value dtyVal = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr( + rewriter.getIntegerType(64), + static_cast( + Torch::getScalarTypeForType(cTy.getDtype())))); + c = rewriter.create( + binder.getLoc(), cTy, c, cScale, cZp, dtyVal); + rewriter.replaceOpWithNewOp(binder.op, resultType, + c); + return success(); + }); + patterns.onOp("Reciprocal", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp( + "Relu", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value x; + if (binder.tensorOperand(x) || binder.tensorResultType(resultType)) + return failure(); + + rewriter.replaceOpWithNewOp(binder.op, resultType, + x); + return success(); + }); + patterns.onOp("Round", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp( + "ScatterElements", 18, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + SmallVector valList; + int64_t axis; + std::string reduction; + int64_t numOperands = binder.op->getNumOperands(); + if (binder.tensorOperands(valList, numOperands) || + binder.s64IntegerAttr(axis, "axis", 0) || + binder.customOpNameStringAttr(reduction, "reduction", "none") || + binder.tensorResultType(resultType)) + return failure(); + + Value data = valList[0]; + Value indices = valList[1]; + Value updates = valList[2]; + + // ONNX allows negative axis. + if (axis < 0) + axis += + cast(data.getType()).getSizes().size(); + + Value constAxis = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); + + if (reduction == "none") { + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, constAxis, indices, updates); + return success(); + } + + // TODO: Implement max and min cases + if (reduction == "mul") { + reduction = "multiply"; + } else if (reduction == "max" || reduction == "min") { + return rewriter.notifyMatchFailure( + binder.op, "max/min reduction unsupported for scatter elements"); + } + + Value cstStrReduction = + rewriter.create(binder.getLoc(), reduction); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, constAxis, indices, updates, + cstStrReduction); + return success(); + }); + patterns.onOp( + "Sigmoid", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value x; + if (binder.tensorOperand(x) || binder.tensorResultType(resultType)) + return failure(); + + rewriter.replaceOpWithNewOp(binder.op, resultType, + x); + return success(); + }); + patterns.onOp("Sin", 7, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp("Tanh", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp("Sqrt", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp( + "Sub", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value x; + Value y; + if (binder.tensorOperands(x, y) || binder.tensorResultType(resultType)) + return failure(); + Value const1 = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, x, y, const1); + return success(); + }); + patterns.onOp( + "Sum", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + if (binder.op->getNumOperands() == 1) { + Torch::ValueTensorType resultType; + Value x; + if (binder.tensorOperand(x) || binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOp(binder.op, x); + return success(); + } + Torch::ValueTensorType resultType; + SmallVector valList; + int64_t numOperands = binder.op->getNumOperands(); + if (binder.tensorOperands(valList, numOperands) || + binder.tensorResultType(resultType)) + return failure(); + Value const1 = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); + // Short circuit to binary add + if (numOperands == 2) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, valList[0], valList[1], const1); + return success(); + } + // When binder.op->getNumOperands() > 2 + Value curr = rewriter.create( + binder.getLoc(), resultType, valList[0], valList[1], const1); + for (int i = 2; i < numOperands; i++) { + if (i == numOperands - 1) { + curr = rewriter.create( + binder.getLoc(), resultType, curr, valList[i], const1); + } else { + SmallVector resultBroadcastShapeInt; + SmallVector resultBroadcastShapeValue; + Torch::computeBroadcastShape(rewriter, binder.getLoc(), curr, + valList[i], resultBroadcastShapeInt, + resultBroadcastShapeValue); + auto baseType = Torch::ValueTensorType::get( + binder.op->getContext(), resultBroadcastShapeInt, + resultType.getOptionalDtype()); + curr = rewriter.create( + binder.getLoc(), baseType, curr, valList[i], const1); + } + } + rewriter.replaceOp(binder.op, curr); + return success(); + }); + patterns.onOp("Where", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + SmallVector valList; + int64_t numOperands = binder.op->getNumOperands(); + if (binder.tensorOperands(valList, numOperands) || + binder.tensorResultType(resultType)) + return failure(); + Value condition = valList[0]; + Value x = valList[1]; + Value y = valList[2]; + rewriter.replaceOpWithNewOp( + binder.op, resultType, condition, x, y); + return success(); + }); + patterns.onOp( + "Xor", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value x; + Value y; + if (binder.tensorOperands(x, y) || binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp(binder.op, + resultType, x, y); + return success(); + }); + patterns.onOp( + "Squeeze", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + Value axes; + if (binder.tensorOperands(data, axes) || + binder.tensorResultType(resultType)) + return failure(); + Torch::BaseTensorType axesType = + axes.getType().cast(); + SmallVector dimList; + SmallVector selectSizes; + selectSizes.push_back(1); + Type selectResultType = axesType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), axesType.getOptionalDtype()); + auto sizes = + dyn_cast(axes.getType()).getSizes(); + if (sizes.size() == 0) { + rewriter.replaceOpWithNewOp(binder.op, + resultType, data); + return success(); + } + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + int64_t adjustmentInt = + cast(data.getType()).getSizes().size(); + Value adjustment = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + adjustmentInt)); + for (int i = 0; i < sizes[0]; i++) { + // Go through the axes list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, axes, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + // deal with neg axis: if (axis < 0) axis += rank + Value isNegative = + rewriter.create(binder.getLoc(), dim, zero); + isNegative = rewriter.create(binder.getLoc(), + isNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isNegative, adjustment); + Value finalDim = rewriter.create( + binder.getLoc(), dim, finalOffset); + dimList.push_back(finalDim); + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + dimList); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, dimValueList); + return success(); + }); + patterns.onOp( + "Unsqueeze", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // Unlike squeeze where we are able to lower to Torch::PrimsSqueezeOp, + // pytorch does not support torch.unsqueeze to insert multiple new dims. + // discussion can be found here: + // https://github.com/pytorch/pytorch/issues/9410 + // So, for now, we unroll into multiple unsqueezes. + Torch::ValueTensorType resultType; + Value data; + Value axes; + if (binder.tensorOperands(data, axes) || + binder.tensorResultType(resultType)) + return failure(); + Torch::BaseTensorType axesType = + axes.getType().cast(); + SmallVector dimList; + SmallVector selectSizes; + selectSizes.push_back(1); + Type selectResultType = axesType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), axesType.getOptionalDtype()); + auto sizes = + dyn_cast(axes.getType()).getSizes(); + if (sizes.size() == 0) { + rewriter.replaceOp(binder.op, data); + return success(); + } + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + int64_t adjustmentInt = + cast(data.getType()).getSizes().size(); + Value adjustment = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + adjustmentInt)); + for (int i = 0; i < sizes[0]; i++) { + // Go through the axes list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, axes, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + // deal with neg axis: if (axis < 0) axis += rank + Value isNegative = + rewriter.create(binder.getLoc(), dim, zero); + isNegative = rewriter.create(binder.getLoc(), + isNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isNegative, adjustment); + Value finalDim = rewriter.create( + binder.getLoc(), dim, finalOffset); + dimList.push_back(finalDim); + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + dimList); + Value cstFalse = + rewriter.create(binder.getLoc(), false); + Value noneVal = rewriter.create(binder.getLoc()); + Value updatedAxes = rewriter.create( + binder.getLoc(), + axesType.getWithSizesAndDtype(sizes, axesType.getOptionalDtype()), + dimValueList, /*dtype=*/noneVal, /*device=*/noneVal, cstFalse); + // Sort the list of dims, so we don't run into this situation: + // data.sizes = [2, 3, 4] + // dims = [4, 0] + // index 4 will be invalid to add a singleton dimension because + // data.sizes.size == 3 We have to work with sorted dims to avoid this + // situation. + auto sortIndicesType = axesType.getWithSizesAndDtype( + axesType.getOptionalSizes(), + IntegerType::get(binder.op->getContext(), 64, IntegerType::Signed)); + auto sortOpResult = rewriter.create( + binder.getLoc(), axes.getType(), sortIndicesType, updatedAxes, zero, + cstFalse); + Value result; + auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation( + binder.op->getContext()); + // Go through the updated, sorted axes. Do unsqueeze for each dim. + for (int i = 0; i < sizes[0]; i++) { + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, sortOpResult->getResult(0), + zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + if (sizes[0] == 1) { + result = rewriter.create( + binder.getLoc(), resultType, data, dim); + } else if (i == 0) { + result = rewriter.create( + binder.getLoc(), baseType, data, dim); + } else if (i == sizes[0] - 1) { + result = rewriter.create( + binder.getLoc(), resultType, result, dim); + } else { + result = rewriter.create( + binder.getLoc(), baseType, result, dim); + } + } + rewriter.replaceOp(binder.op, result); + return success(); + }); + patterns.onOp( + "Softmax", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input; + int64_t axis; + if (binder.tensorOperand(input) || + binder.s64IntegerAttr(axis, "axis", -1) || + binder.tensorResultType(resultType)) + return failure(); + + // ONNX allows negative axis. + if (axis < 0) + axis += + cast(input.getType()).getSizes().size(); + + Value constAxis = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); + + Value noneVal = rewriter.create(binder.getLoc()); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, constAxis, /*dtype=*/noneVal); + return success(); + }); + + patterns.onOp( + "Selu", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + float alpha, gamma; + Value operand; + if (binder.tensorOperand(operand) || + binder.f32FloatAttr(alpha, "alpha") || + binder.f32FloatAttr(gamma, "gamma") || + binder.tensorResultType(resultType)) + return failure(); + + Value vAlpha = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), alpha)); + + Value vScale = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), gamma)); + + Value vInputScale = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), 1.0)); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, vAlpha, vScale, vInputScale); + return success(); + }); + patterns.onOp( + "ReduceSum", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + Value axes; + int64_t keepDims; + int64_t noop_with_empty_axes; + if (binder.tensorOperands(data, axes) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); + Torch::BaseTensorType axesType = + axes.getType().cast(); + SmallVector dimList; + SmallVector selectSizes; + selectSizes.push_back(1); + Type selectResultType = axesType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), axesType.getOptionalDtype()); + auto sizes = + dyn_cast(axes.getType()).getSizes(); + Value noneVal = rewriter.create(binder.getLoc()); + // Deal with case when axes is empty + if (sizes.size() == 1 && sizes[0] == 0) { + if (noop_with_empty_axes == 0) { + Value keepDimsConstInt = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims)); + Value keepDimsBool = rewriter.create( + binder.getLoc(), keepDimsConstInt); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, /*dim=*/noneVal, + /*keepdim=*/keepDimsBool, /*dtype=*/noneVal); + } else { + rewriter.replaceOp(binder.op, data); + } + return success(); + } + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + int64_t adjustmentInt = + cast(data.getType()).getSizes().size(); + Value adjustment = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + adjustmentInt)); + // convert axes (tensor) into torch int list while dealing with neg axis + for (int i = 0; i < sizes[0]; i++) { + // Go through the axes list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, axes, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + // deal with neg axis: if (axis < 0) axis += rank + Value isNegative = + rewriter.create(binder.getLoc(), dim, zero); + isNegative = rewriter.create(binder.getLoc(), + isNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isNegative, adjustment); + Value finalDim = rewriter.create( + binder.getLoc(), dim, finalOffset); + dimList.push_back(finalDim); + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + dimList); + Value keepDimBool; + if (keepDims == 1) { + keepDimBool = + rewriter.create(binder.getLoc(), true); + } else { + keepDimBool = + rewriter.create(binder.getLoc(), false); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, dimValueList, keepDimBool, + /*dtype=*/noneVal); + return success(); + }); + patterns.onOp( + "ReduceMean", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + int64_t keepDims; + int64_t noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); + + SmallVector axesList; + + Value axesVal; + if (!binder.tensorOperandAtIndex(axesVal, 1)) { + Torch::BaseTensorType axesType = + axesVal.getType().cast(); + SmallVector dimList; + SmallVector selectSizes{1}; + auto selType = rewriter.getType( + selectSizes, axesType.getOptionalDtype()); + auto axesTy = dyn_cast(axesVal.getType()); + auto axesShape = axesTy.getSizes(); + + if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize) + return failure(); + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getI64IntegerAttr(0)); + int64_t numAxes = axesShape[0]; + for (int64_t i = 0; i < numAxes; ++i) { + Value iv = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getI64IntegerAttr(i)); + Value extract = rewriter.create( + binder.getLoc(), selType, axesVal, zero, iv); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + axesList.push_back(dim); + } + } + + SmallVector axesInts; + if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) { + for (int64_t i = 0, s = axesInts.size(); i < s; ++i) { + Value iv = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getI64IntegerAttr(axesInts[i])); + axesList.push_back(iv); + } + } + + // deal with case when axes is empty + if (axesList.empty() && noop_with_empty_axes) { + rewriter.replaceOp(binder.op, data); + return success(); + } + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getI64IntegerAttr(0)); + int64_t adjustmentInt = + cast(data.getType()).getSizes().size(); + Value adjustment = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getI64IntegerAttr(adjustmentInt)); + + // Handle if the axes value is less than zero: + for (int i = 0, s = axesList.size(); i < s; i++) { + Value isNegative = rewriter.create( + binder.getLoc(), axesList[i], zero); + isNegative = rewriter.create(binder.getLoc(), + isNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isNegative, adjustment); + Value finalDim = rewriter.create( + binder.getLoc(), axesList[i], finalOffset); + axesList[i] = finalDim; + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + axesList); + Value keepDimBool = + rewriter.create(binder.getLoc(), keepDims); + Value noneVal = rewriter.create(binder.getLoc()); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, dimValueList, keepDimBool, + /*dtype=*/noneVal); + return success(); + }); + patterns.onOp( + "ReduceMax", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // AtenAmaxOp allows us to pass a list of dims + Torch::ValueTensorType resultType; + Value data; + Value axes; + int64_t keepDims; + int64_t noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); + + auto dataTy = cast(data.getType()); + Torch::IntType torchIntTy = rewriter.getType(); + + // If any of the input dims are 0 we set to the upper limit: + if (llvm::any_of(dataTy.getSizes(), [](int64_t d) { return d == 0; }) && + (llvm::any_of(dataTy.getSizes(), + [](int64_t d) { return d == Torch::kUnknownSize; }) || + keepDims)) { + auto dty = dataTy.getDtype(); + Value scalar; + if (FloatType fpTy = dyn_cast(dty)) { + auto inf = APFloat::getInf(fpTy.getFloatSemantics()); + scalar = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), + inf.convertToDouble())); + } + + if (IntegerType intTy = dyn_cast(dty)) { + auto mx = + intTy.isSigned() + ? APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth()) + : APInt::getMaxValue(intTy.getIntOrFloatBitWidth()); + scalar = rewriter.create( + binder.getLoc(), torchIntTy, + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + mx.getSExtValue())); + } + + llvm::SmallVector fillDims; + for (int i = 0, s = resultType.getSizes().size(); i < s; ++i) { + auto staticDim = resultType.getSizes()[i]; + if (staticDim != Torch::kUnknownSize) { + fillDims.push_back(rewriter.create( + binder.getLoc(), torchIntTy, + rewriter.getI64IntegerAttr(staticDim))); + continue; + } + + Value iv = rewriter.create( + binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i)); + fillDims.push_back(rewriter.create( + binder.getLoc(), torchIntTy, data, iv)); + } + + Value none = rewriter.create(binder.getLoc()); + Value fillDimsList = rewriter.create( + binder.getLoc(), Torch::ListType::get(torchIntTy), fillDims); + rewriter.replaceOpWithNewOp( + binder.op, resultType, fillDimsList, scalar, none, none, none, + none); + return success(); + } + + // Previous version of the operation had the axes as an attribute: + SmallVector axesList; + llvm::SmallVector axesAttr; + if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) { + for (int i = 0, s = axesAttr.size(); i < s; ++i) { + axesList.push_back(rewriter.create( + binder.getLoc(), torchIntTy, + rewriter.getI64IntegerAttr(axesAttr[i]))); + } + } + + // Extract the axes values from the axes operand: + if (!binder.tensorOperandAtIndex(axes, 1)) { + Torch::BaseTensorType axesType = + axes.getType().cast(); + SmallVector selectSizes{1}; + Type selectResultType = axesType.getWithSizesAndDtype( + selectSizes, axesType.getOptionalDtype()); + auto sizes = axesType.getSizes(); + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + + // Extract the value of each axes: + for (int i = 0; i < sizes[0]; i++) { + // Go through the axes list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, axes, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + axesList.push_back(dim); + } + } + + // Handle the noop case: + if (axesList.empty() && noop_with_empty_axes) { + rewriter.replaceOp(binder.op, data); + return success(); + } + + // Deal with case when no axes arg is passed but not a noop: + if (axesList.empty()) { + int64_t numDims = dyn_cast(data.getType()) + .getSizes() + .size(); + for (int i = 0; i < numDims; i++) { + Value curr = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + axesList.push_back(curr); + } + } + + // Handle negative axis: + Value rankVal = rewriter.create(binder.getLoc(), + torchIntTy, data); + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getI64IntegerAttr(0)); + for (Value &axes : axesList) { + Value isNegative = + rewriter.create(binder.getLoc(), axes, zero); + isNegative = rewriter.create(binder.getLoc(), + isNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isNegative, rankVal); + axes = rewriter.create(binder.getLoc(), axes, + finalOffset); + } + + Value dimValueList = rewriter.create( + binder.getLoc(), Torch::ListType::get(torchIntTy), axesList); + Value keepDimBool = + rewriter.create(binder.getLoc(), keepDims); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, dimValueList, keepDimBool); + return success(); + }); + + patterns.onOp( + "ReduceMin", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // AtenAminOp allows us to pass a list of dims + Torch::ValueTensorType resultType; + Value data; + Value axes; + int64_t keepDims; + int64_t noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); + + auto dataTy = cast(data.getType()); + Torch::IntType torchIntTy = rewriter.getType(); + + // If any of the input dims are 0 we set to the upper limit: + if (llvm::any_of(dataTy.getSizes(), [](int64_t d) { return d == 0; }) && + (llvm::any_of(dataTy.getSizes(), + [](int64_t d) { return d == Torch::kUnknownSize; }) || + keepDims)) { + auto dty = dataTy.getDtype(); + Value scalar; + if (FloatType fpTy = dyn_cast(dty)) { + auto inf = APFloat::getInf(fpTy.getFloatSemantics()); + scalar = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), + inf.convertToDouble())); + } + + if (IntegerType intTy = dyn_cast(dty)) { + auto mx = + intTy.isSigned() + ? APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth()) + : APInt::getMaxValue(intTy.getIntOrFloatBitWidth()); + scalar = rewriter.create( + binder.getLoc(), torchIntTy, + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + mx.getSExtValue())); + } + + llvm::SmallVector fillDims; + for (int i = 0, s = resultType.getSizes().size(); i < s; ++i) { + auto staticDim = resultType.getSizes()[i]; + if (staticDim != Torch::kUnknownSize) { + fillDims.push_back(rewriter.create( + binder.getLoc(), torchIntTy, + rewriter.getI64IntegerAttr(staticDim))); + continue; + } + + Value iv = rewriter.create( + binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i)); + fillDims.push_back(rewriter.create( + binder.getLoc(), torchIntTy, data, iv)); + } + + Value none = rewriter.create(binder.getLoc()); + Value fillDimsList = rewriter.create( + binder.getLoc(), Torch::ListType::get(torchIntTy), fillDims); + rewriter.replaceOpWithNewOp( + binder.op, resultType, fillDimsList, scalar, none, none, none, + none); + return success(); + } + + // Previous version of the operation had the axes as an attribute: + SmallVector axesList; + llvm::SmallVector axesAttr; + if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) { + for (int i = 0, s = axesAttr.size(); i < s; ++i) { + axesList.push_back(rewriter.create( + binder.getLoc(), torchIntTy, + rewriter.getI64IntegerAttr(axesAttr[i]))); + } + } + + // Extract the axes values from the axes operand: + if (!binder.tensorOperandAtIndex(axes, 1)) { + Torch::BaseTensorType axesType = + axes.getType().cast(); + SmallVector selectSizes{1}; + Type selectResultType = axesType.getWithSizesAndDtype( + selectSizes, axesType.getOptionalDtype()); + auto sizes = axesType.getSizes(); + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + + // Extract the value of each axes: + for (int i = 0; i < sizes[0]; i++) { + // Go through the axes list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, axes, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + axesList.push_back(dim); + } + } + + // Handle the noop case: + if (axesList.empty() && noop_with_empty_axes) { + rewriter.replaceOp(binder.op, data); + return success(); + } + + // Deal with case when no axes arg is passed but not a noop: + if (axesList.empty()) { + int64_t numDims = dyn_cast(data.getType()) + .getSizes() + .size(); + for (int i = 0; i < numDims; i++) { + Value curr = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + axesList.push_back(curr); + } + } + + // Handle negative axis: + Value rankVal = rewriter.create(binder.getLoc(), + torchIntTy, data); + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getI64IntegerAttr(0)); + for (Value &axes : axesList) { + Value isNegative = + rewriter.create(binder.getLoc(), axes, zero); + isNegative = rewriter.create(binder.getLoc(), + isNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isNegative, rankVal); + axes = rewriter.create(binder.getLoc(), axes, + finalOffset); + } + + Value dimValueList = rewriter.create( + binder.getLoc(), Torch::ListType::get(torchIntTy), axesList); + Value keepDimBool = + rewriter.create(binder.getLoc(), keepDims); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, dimValueList, keepDimBool); + return success(); + }); + + patterns.onOp("Shape", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + + patterns.onOp("Sinh", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + + // split with fixed-size parts + // Arguments: + // - input: the tensor to split + // Attributes: + // - axis: the axis along which to split the input + // - num_outputs: the number of outputs to produce + // Outputs: + // - outputs: the produced outputs. Variadic with num_outputs elements. + // Note: torch.aten gives a list of tensors, but ONNX gives a variadic list of + // tensors + // so we need to unpack the list + patterns.onOp( + "Split", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Value self; + int64_t axis; + int64_t num_outputs; + if (binder.tensorOperand(self)) + return rewriter.notifyMatchFailure( + binder.op, "Not converting to AtenSplitTensorOp due to input " + "tensor mismatch"); + if (binder.s64IntegerAttr(axis, "axis", 0)) + return rewriter.notifyMatchFailure(binder.op, + "Failed to get axis attribute"); + if (binder.s64IntegerAttr(num_outputs, "num_outputs", 0)) + return rewriter.notifyMatchFailure( + binder.op, "Failed to get num_outputs attribute"); + + auto result0Ty = + binder.op->getResult(0).getType().cast(); + auto selfTy = self.getType().cast(); + + int64_t dim = axis; + if (dim < 0) + dim += selfTy.getSizes().size(); + + // set intermediate shape to the shape of the first result + // if the results are of different shapes + // set the splitted axis to variable shape + llvm::SmallVector intermediateShape(result0Ty.getSizes()); + for (auto result : binder.op->getResultTypes()) { + int64_t d = result.cast().getSizes()[dim]; + intermediateShape[dim] = d == intermediateShape[dim] ? d : -1; + } + + Value dimValue = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), dim)); + + Value splitSize = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), num_outputs)); + + // TODO: Attempting to use the shape expected by the ONNX mlir as ground + // truth. For now just use dynamic shapes. + auto resultOuterType = + Torch::ListType::get(rewriter.getType( + /*std::optional>=*/intermediateShape, + result0Ty.getOptionalDtype())); + Torch::AtenSplitTensorOp new_op = + rewriter.create( + binder.getLoc(), resultOuterType, self, splitSize, dimValue); + + // the onnx op is variadic with multiple results, but AtenSplitWithSizes + // outputs a list so we need to unpack the list + rewriter.replaceOpWithNewOp( + binder.op, binder.op->getResults().getType(), new_op.getResult()); + + return success(); + }); + + // split with variable parts + // Arguments: + // - input: the tensor to split + // - split: the sizes of the splits to be produced + // Attributes: + // - axis: the axis along which to split the input + // - num_outputs: the number of outputs to produce + // Outputs: + // - outputs: the produced outputs. Variadic with num_outputs elements. + // Note: torch.aten gives a list of tensors, but ONNX gives a variadic list of + // tensors + // so we need to unpack the list + patterns.onOp( + "Split", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Value self; + Value split; + int64_t axis; + int64_t num_outputs; + if (binder.tensorOperandAtIndex(self, 0) || + binder.tensorOperandAtIndex(split, 1)) + return rewriter.notifyMatchFailure( + binder.op, "Not converting to AtenSplitWithSizesOp due to input " + "tensor mismatch"); + if (binder.s64IntegerAttr(axis, "axis", 0)) + return rewriter.notifyMatchFailure(binder.op, + "Failed to get axis attribute"); + if (binder.s64IntegerAttr(num_outputs, "num_outputs", 0)) + return rewriter.notifyMatchFailure( + binder.op, "Failed to get num_outputs attribute"); + + auto result0Ty = + binder.op->getResult(0).getType().cast(); + auto selfTy = + cast(binder.op->getOperand(0).getType()); + + int64_t dim = axis; + if (dim < 0) + dim += selfTy.getSizes().size(); + + llvm::SmallVector intermediateShape(result0Ty.getSizes()); + for (auto result : binder.op->getResultTypes()) { + int64_t d = result.cast().getSizes()[dim]; + intermediateShape[dim] = d == intermediateShape[dim] ? d : -1; + } + + Torch::PrimTolistOp splitToList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(rewriter.getType()), split); + + Value dimValue = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), dim)); + + // TODO: Attempting to use the shape expected by the ONNX mlir as ground + // truth. For now just use dynamic shapes. + auto resultOuterType = + Torch::ListType::get(rewriter.getType( + /*std::optional>=*/intermediateShape, + result0Ty.getOptionalDtype())); + Torch::AtenSplitWithSizesOp new_op = + rewriter.create( + binder.getLoc(), resultOuterType, self, + splitToList.getResult(0), dimValue); + + // the onnx op is variadic with multiple results, but AtenSplitWithSizes + // outputs a list so we need to unpack the list + rewriter.replaceOpWithNewOp( + binder.op, binder.op->getResults().getType(), new_op.getResult()); + + return success(); + }); + + patterns.onOp("Tan", 7, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + + patterns.onOp( + "Transpose", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + auto loc = binder.getLoc(); + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + auto operandType = operand.getType().cast(); + TensorType tensorType = operandType.toBuiltinTensor(); + if (!tensorType || !tensorType.hasRank()) + return failure(); + + // Default permutation is to reverse orders: + int64_t rank = tensorType.getRank(); + llvm::SmallVector reverse(rank); + for (int64_t i = 0; i < rank; ++i) { + reverse[i] = rank - i - 1; + } + + llvm::SmallVector permutations; + if (failed(binder.s64IntegerArrayAttr(permutations, "perm", reverse))) + return rewriter.notifyMatchFailure(binder.op, + "Failed to obtain permutations"); + + if (static_cast(permutations.size()) != rank) + return rewriter.notifyMatchFailure( + binder.op, "Permutation length does not match operand rank"); + + llvm::SmallVector shape(tensorType.getShape()); + llvm::SmallVector current(rank); + for (int64_t i = 0; i < rank; ++i) { + current[i] = i; + } + + for (auto &dim : permutations) + dim = dim < 0 ? dim + rank : dim; + + // We need to override to the destination if known: + if (resultType.hasSizes()) { + for (int i = 0; i < rank; ++i) { + shape[permutations[i]] = resultType.getSizes()[i]; + } + } + + // Convert dynamic shape dimension: + for (unsigned i = 0; i < shape.size(); i++) { + if (shape[i] == ShapedType::kDynamic) + shape[i] = Torch::kUnknownSize; + } + + for (int64_t i = 0; i < rank; ++i) { + if (current[i] == permutations[i]) + continue; + + int64_t target = i + 1; + for (; target < rank; ++target) { + if (current[target] == permutations[i]) + break; + } + + std::swap(shape[i], shape[target]); + std::swap(current[i], current[target]); + + Value dim0 = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + + Value dim1 = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), target)); + + operand = rewriter.create( + loc, + Torch::ValueTensorType::get(tensorType.getContext(), shape, + operandType.getOptionalDtype()), + operand, dim0, dim1); + } + + rewriter.replaceOp(binder.op, operand); + return success(); + }); + patterns.onOp( + "Slice", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultTorchType; + Value operand, starts, ends; + // Handle if axes are not provided + + if (binder.tensorOperandAtIndex(operand, 0) || + binder.tensorOperandAtIndex(starts, 1) || + binder.tensorOperandAtIndex(ends, 2) || + binder.tensorResultType(resultTorchType)) { + return failure(); + } + + auto context = rewriter.getContext(); + auto operandTorchTy = operand.getType().cast(); + auto operandTy = + operandTorchTy.toBuiltinTensor().dyn_cast(); + + if (!operandTy) + return rewriter.notifyMatchFailure( + binder.op, + "Expected tensor operator argument to be a ranked tensor type"); + + auto startsTorchTy = starts.getType().cast(); + auto startsTy = + startsTorchTy.toBuiltinTensor().dyn_cast(); + int startSize = startsTy.getDimSize(0); + + auto endsTorchTy = ends.getType().cast(); + auto endsTy = + endsTorchTy.toBuiltinTensor().dyn_cast(); + int endSize = endsTy.getDimSize(0); + auto resultTy = + resultTorchType.toBuiltinTensor().dyn_cast(); + if (!resultTy) + return rewriter.notifyMatchFailure( + binder.op, "Expected result type to be a ranked tensor type"); + + Location loc = binder.getLoc(); + + // Binding `axes` from its arguments or through a default value + Value axes; + if (binder.getNumOperands() >= 4) { + if (binder.tensorOperandAtIndex(axes, 3)) { + return failure(); + } + } + + // Binding `steps` from its arguments or through a default value + Value steps; + if (binder.getNumOperands() >= 5) { + if (binder.tensorOperandAtIndex(steps, 4)) { + return failure(); + } + } else { + // The default `steps` value is a 1d tensor filled with ones with a + // size equal to the size of `starts` and `ends`. + Value none = rewriter.create(loc); + Value sizeStepInput = rewriter.create( + loc, rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), startSize)); + Value sizeStepsInput = rewriter.create( + loc, + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + sizeStepInput); + steps = rewriter.create( + loc, startsTorchTy, sizeStepsInput, none, none, none, none); + } + + if (!(endsTy.getRank() == 1 && startsTy.getRank() == 1 && + startSize == endSize)) + return rewriter.notifyMatchFailure( + binder.op, "Expected the rank of starts and ends tensors to be 1 " + "and their dimensions to match"); + + if (axes) { + auto axesTorchTy = axes.getType().cast(); + auto axesTy = + axesTorchTy.toBuiltinTensor().dyn_cast(); + int64_t numAxes = axesTy.getDimSize(0); + + if (!(axesTy && numAxes == endSize)) + return rewriter.notifyMatchFailure( + binder.op, "Axes should be the same size of starts and ends"); + } + + auto stepsTy = steps.getType() + .cast() + .toBuiltinTensor() + .dyn_cast(); + + if (!(stepsTy && stepsTy.getDimSize(0) == endsTy.getDimSize(0))) + return rewriter.notifyMatchFailure( + binder.op, "Steps should be the same size of starts and ends"); + + Value zero = rewriter.create( + loc, rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + + auto select = [&](Value v, Value k) -> Value { + auto ty = v.getType().cast(); + auto sel = rewriter.create( + loc, + Torch::ValueTensorType::get(ty.getContext(), ArrayRef{1}, + ty.getOptionalDtype()), + v, zero, k); + Value item = rewriter.create( + loc, rewriter.getType(), sel); + return item; + }; + + llvm::SmallVector intermediateShape(operandTy.getShape()); + for (int i = 0, s = operandTy.getRank(); i < s; ++i) { + if (operandTy.getDimSize(i) != resultTy.getDimSize(i)) + intermediateShape[i] = -1; + if (intermediateShape[i] == ShapedType::kDynamic) + intermediateShape[i] = Torch::kUnknownSize; + } + auto intermediateType = Torch::ValueTensorType::get( + context, intermediateShape, resultTorchType.getOptionalDtype()); + for (int i = 0; i < endSize; ++i) { + + Value k = rewriter.create( + loc, rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value kTensor = rewriter.create( + loc, + Torch::ValueTensorType::get( + context, ArrayRef{1}, + rewriter.getIntegerType(64, /*signed*/ 1)), + k); + + Value start = select(starts, kTensor); + Value end = select(ends, kTensor); + Value axis = axes ? select(axes, kTensor) : k; + Value step = select(steps, kTensor); + + auto sliceType = intermediateType; + sliceType = i == (endSize - 1) ? resultTorchType : sliceType; + operand = rewriter.create( + loc, sliceType, operand, axis, start, end, step); + } + + rewriter.replaceOp(binder.op, operand); + return success(); + }); + patterns.onOp( + "Reshape", 5, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + Value shape; + int64_t allowzero; + if (binder.tensorOperands(data, shape) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(allowzero, "allowzero", 0)) + return failure(); + + // If the result shape is static then we can create a result shape list + // directly using the result shape values (integers). + if (resultType.hasSizes()) { + bool hasStaticShape = resultType.areAllSizesKnown(); + ArrayRef resultShapeInt = resultType.getSizes(); + if (hasStaticShape) { + SmallVector resultShape; + for (int64_t dim : resultShapeInt) { + resultShape.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(dim))); + } + Value resultShapeList = rewriter.create( + binder.getLoc(), + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + resultShape); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, resultShapeList); + return success(); + } + } + + Torch::BaseTensorType shapeType = + shape.getType().cast(); + SmallVector dimList; + SmallVector selectSizes; + selectSizes.push_back(1); + Type selectResultType = shapeType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype()); + auto shapeSizes = + dyn_cast(shape.getType()).getSizes(); + auto dataSizes = + dyn_cast(data.getType()).getSizes(); + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + if (allowzero == 0) { + // convert shape (tensor) into torch int list while dealing with zero + // vals + for (int i = 0; i < shapeSizes[0]; i++) { + // Go through the shape list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, shape, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + // deal with zero axis values: replace with original dim value in + // input + Value isZero = + rewriter.create(binder.getLoc(), dim, zero); + isZero = + rewriter.create(binder.getLoc(), isZero); + Value adjustment; + int64_t inputDimsSize = dataSizes.size(); + if (i < inputDimsSize) { + adjustment = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + dataSizes[i])); + } + // Will never have a 0 in the shape tensor input at an index out of + // bounds of original input dims Therefore, no need to adjust + else { + adjustment = zero; + } + Value finalOffset = rewriter.create( + binder.getLoc(), isZero, adjustment); + Value finalDim = rewriter.create( + binder.getLoc(), dim, finalOffset); + dimList.push_back(finalDim); + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + dimList); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, dimValueList); + return success(); + } + // convert axes (tensor) into torch int list + for (int i = 0; i < shapeSizes[0]; i++) { + // Go through the axes list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, shape, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + dimList.push_back(dim); + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + dimList); + rewriter.replaceOpWithNewOp(binder.op, resultType, + data, dimValueList); + return success(); + }); + patterns.onOp( + "ReduceProd", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // ReduceProd allows us to pass a list of dims but AtenProdDimIn only + // allow one dim as input. + Torch::ValueTensorType resultType; + Value data; + Value axes; + int64_t keepDims; + int64_t noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); + + auto dataTy = cast(data.getType()); + Torch::IntType torchIntTy = rewriter.getType(); + + if (!resultType.hasSizes() || !resultType.areAllSizesKnown() || + !dataTy.areAllSizesKnown()) + return rewriter.notifyMatchFailure( + binder.op, + "Expected the input and result type to have known sizes"); + + int64_t rank = dataTy.getSizes().size(); + SmallVector axesList; + Value zero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + + // Previous version of the operation had the axes as an attribute: + llvm::SmallVector axesAttr; + if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) { + for (int i = 0, s = axesAttr.size(); i < s; ++i) { + axesList.push_back(rewriter.create( + binder.getLoc(), torchIntTy, + rewriter.getI64IntegerAttr(axesAttr[i]))); + } + } + + // Handle cases that axes are explicitly specified. + // Extract the axes values from the axes operand. + // This really shouldn't happen but it helps pass weird tests. + // TODO: Derive the chosen axes from the data type and final result type + // instead of using the dynamic axes at operand[1]. + if (!binder.tensorOperandAtIndex(axes, 1)) { + Torch::BaseTensorType axesType = + axes.getType().cast(); + auto sizes = axesType.getSizes(); + for (int i = 0; i < sizes[0]; i++) { + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i)); + Value extract = rewriter.create( + binder.getLoc(), + axesType.getWithSizesAndDtype(llvm::SmallVector{1}, + axesType.getOptionalDtype()), + axes, zero, selectIndex); + Value dim = rewriter.create(binder.getLoc(), + torchIntTy, extract); + axesList.push_back(dim); + } + } + + // Handle the noop case: + // When axes is empty and noop_with_empty_axes is set to true, input + // tensor will not be reduced, and the output tensor would be + // equivalent to input tensor. + if (axesList.empty() && noop_with_empty_axes) { + rewriter.replaceOp(binder.op, data); + return success(); + } + + // Handle case when no axes arg is passed but not a noop: + // Manually set positive axis to all dims. + if (axesList.empty()) { + for (int i = 0; i < rank; i++) { + Value dimValue = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i)); + axesList.push_back(dimValue); + } + } + + // Handle negative axis: + Value rankVal = rewriter.create(binder.getLoc(), + torchIntTy, data); + for (Value &axes : axesList) { + Value isNegative = + rewriter.create(binder.getLoc(), axes, zero); + isNegative = rewriter.create(binder.getLoc(), + isNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isNegative, rankVal); + axes = rewriter.create(binder.getLoc(), axes, + finalOffset); + } + + // Handle multiple axes case: + // ReduceProd on each dim, always set keepDimsBool == True to avoid + // segfault. + Value trueVal = + rewriter.create(binder.getLoc(), true); + Value noneVal = rewriter.create(binder.getLoc()); + SmallVector intermediateShape(rank, Torch::kUnknownSize); + Value dataReduceProd = data; + for (int i = 0, numAxes = axesList.size(); i < numAxes; i++) { + auto axis = axesList[i]; + if (keepDims && i == numAxes - 1) { + dataReduceProd = rewriter.create( + binder.getLoc(), + dataTy.getWithSizesAndDtype(resultType.getSizes(), + dataTy.getOptionalDtype()), + dataReduceProd, axis, trueVal, noneVal); + rewriter.replaceOp(binder.op, dataReduceProd); + return success(); + } + Type resultTyReduceProd = dataTy.getWithSizesAndDtype( + ArrayRef(intermediateShape), dataTy.getOptionalDtype()); + dataReduceProd = rewriter.create( + binder.getLoc(), resultTyReduceProd, dataReduceProd, axis, + trueVal, noneVal); + } + + // Derived the final shape of the tensor after prod loop of each axis. + SmallVector dataReduceProdSize; + auto dataSize = dataTy.getSizes(); + auto resultTypeSizes = resultType.getSizes(); + if (!keepDims) { + // Handle the keepDimsBool == False case: + // 2 point algorithm to derive the static shape after prod loop. + int j = 0; + for (int i = 0; i < rank; i++) { + if (resultTypeSizes.size() && dataSize[i] == resultTypeSizes[j]) { + dataReduceProdSize.push_back(resultTypeSizes[i]); + j++; + continue; + } + dataReduceProdSize.push_back(1); + } + } + + // Handle the keepDimsBool == False case: + // Reshape the prod loop result to the final result shape. + SmallVector dataReduceProdShape; + for (auto dim : dataReduceProdSize) + dataReduceProdShape.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(dim))); + Value dataReduceProdShapeList = + rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + dataReduceProdShape); + rewriter.replaceOpWithNewOp( + binder.op, resultType, dataReduceProd, dataReduceProdShapeList); + return success(); + }); + patterns.onOp( + "Range", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // ONNX.Range(start, limit, delta) -- limit is exclusive + + Torch::ValueTensorType resultType; + Value start, limit, delta; + auto loc = binder.getLoc(); + Value none = rewriter.create(loc); + if (binder.tensorOperandAtIndex(start, 0) || + binder.tensorOperandAtIndex(limit, 1) || + binder.tensorOperandAtIndex(delta, 2) || + binder.tensorResultType(resultType)) + return failure(); + + // Convert a 0-dimensional/Scalar Tensor ([]) to Scalar Torch Numeric + // Value torch.tensor(1.1) equivalent in ONNX to 1.1 as an example + // type of start, limit, delta can be one of: double, float, int16, + // int32, int64 Assuming start, limit and delta to be same type (could + // they be different?) + Torch::BaseTensorType startTensorType = + start.getType().cast(); + bool isFloatDType = startTensorType.getDtype().isF64() || + startTensorType.getDtype().isF32(); + bool isIntDType = startTensorType.getDtype().isInteger(16) || + startTensorType.getDtype().isInteger(32) || + startTensorType.getDtype().isInteger(64); + if (!isFloatDType && !isIntDType) { + return rewriter.notifyMatchFailure( + binder.op, "Expected the start, limit, delta to be one of " + "double, float, int16, int32, int64"); + } + Value scalarStart, scalarLimit, scalarDelta; + if (isFloatDType) { + scalarStart = getItemOp(binder, rewriter, start); + scalarLimit = getItemOp(binder, rewriter, limit); + scalarDelta = getItemOp(binder, rewriter, delta); + } else { + scalarStart = getItemOp(binder, rewriter, start); + scalarLimit = getItemOp(binder, rewriter, limit); + scalarDelta = getItemOp(binder, rewriter, delta); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, scalarStart, scalarLimit, scalarDelta, none, + none, none, none); + return success(); + }); + patterns.onOp( + "Size", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + + auto loc = binder.getLoc(); + auto &op = binder.op; + auto operandTy = cast(operand.getType()); + + if (!operandTy.hasSizes()) + return rewriter.notifyMatchFailure(op, "input rank unknown"); + + llvm::SmallVector dims; + int64_t rank = operandTy.getSizes().size(); + for (int i = 0; i < rank; ++i) { + auto iv = rewriter.create( + loc, rewriter.getI64IntegerAttr(i)); + Value dim = rewriter.create( + loc, rewriter.getType(), operand, iv); + dims.push_back(dim); + } + + Value cstFalse = rewriter.create(loc, false); + Value none = rewriter.create(loc); + + if (dims.empty()) { + Value one = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + rewriter.replaceOpWithNewOp( + op, resultType, one, none, none, cstFalse); + return success(); + } + + Value prod = dims[0]; + for (int i = 1, s = dims.size(); i < s; ++i) + prod = rewriter.create(loc, prod, dims[i]); + + rewriter.replaceOpWithNewOp( + op, resultType, prod, none, none, cstFalse); + return success(); + }); + patterns.onOp( + "Tile", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + Value repeatDims; + if (binder.tensorOperands(operand, repeatDims) || + binder.tensorResultType(resultType)) + return failure(); + + // convert repeatDims tensor to list of ints + auto repeatDimsSizes = + dyn_cast(repeatDims.getType()).getSizes(); + SmallVector dimList; + SmallVector selectSizes; + selectSizes.push_back(1); + Torch::BaseTensorType shapeType = + repeatDims.getType().cast(); + Type selectResultType = shapeType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype()); + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + for (int i = 0; i < repeatDimsSizes[0]; i++) { + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, repeatDims, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + dimList.push_back(dim); + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + dimList); + + rewriter.replaceOpWithNewOp(binder.op, resultType, + operand, dimValueList); + return success(); + }); + patterns.onOp( + "Topk", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType Values_type, Indices_type; + Value X, K; + int64_t axis; + bool largest, sorted; + if (binder.tensorOperandAtIndex(X, 0) || + binder.tensorOperandAtIndex(K, 1) || + binder.s64IntegerAttr(axis, "axis", -1) || + binder.s64BoolAttr(largest, "largest", true) || + binder.s64BoolAttr(sorted, "sorted", true) || + binder.tensorResultTypeAtIndex(Values_type, 0) || + binder.tensorResultTypeAtIndex(Indices_type, 1)) + return failure(); + std::optional maybeRank = Torch::getTensorRank(X); + if (!maybeRank) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: unranked tensor"); + unsigned rank = *maybeRank; + axis = Torch::toPositiveDim(axis, rank); + Value cstAxis = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(axis)); + Value cstLargest = + rewriter.create(binder.getLoc(), largest); + Value cstSorted = + rewriter.create(binder.getLoc(), sorted); + rewriter.replaceOpWithNewOp( + binder.op, Values_type, Indices_type, X, K, cstAxis, cstLargest, + cstSorted); + return success(); + }); + patterns.onOp("Sign", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); +} diff --git a/lib/Conversion/TorchOnnxToTorch/Utils.cpp b/lib/Conversion/TorchOnnxToTorch/Utils.cpp new file mode 100644 index 000000000000..ef3da8b3b3fa --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/Utils.cpp @@ -0,0 +1,49 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h" +#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::onnx_c; + +Value mlir::torch::onnx_c::createConstantIntList( + OpBinder binder, ConversionPatternRewriter &rewriter, + SmallVector cstInput) { + SmallVector cstValue; + for (int64_t i : cstInput) { + cstValue.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + return rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstValue); +} + +Type mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) { + Torch::ValueTensorType tty = dyn_cast(ty); + if (!tty) + return nullptr; + + auto ctx = ty.getContext(); + Type dty = tty.getDtype(); + + if (dty.isUnsignedInteger(8)) + dty = Torch::QUInt8Type::get(ctx); + if (dty.isSignedInteger(8)) + dty = Torch::QInt8Type::get(ctx); + if (dty.isSignedInteger(32)) + dty = Torch::QInt32Type::get(ctx); + + if (!dty) + return nullptr; + return Torch::ValueTensorType::get(ctx, tty.getOptionalSizes(), dty); +} diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index e1e53acb2363..0ca2d108a5e3 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -43,7 +43,8 @@ class ConvertAtenDimOp : public OpConversionPattern { LogicalResult matchAndRewrite(AtenDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto rank = rewriter.create(op->getLoc(), adaptor.getSelf()); + auto rank = + rewriter.create(op->getLoc(), adaptor.getSelf()); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), rank); return success(); @@ -74,7 +75,8 @@ class ConvertAtenBinaryOp : public OpConversionPattern { matchAndRewrite(AtenOp op, typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.template replaceOpWithNewOp(op, adaptor.getA(), adaptor.getB()); + rewriter.template replaceOpWithNewOp(op, adaptor.getA(), + adaptor.getB()); return success(); } }; @@ -112,10 +114,10 @@ class ConvertAtenDivIntOp : public OpConversionPattern { typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value a = - convertScalarToDtype(rewriter, loc, adaptor.getA(), rewriter.getF64Type()); - Value b = - convertScalarToDtype(rewriter, loc, adaptor.getB(), rewriter.getF64Type()); + Value a = convertScalarToDtype(rewriter, loc, adaptor.getA(), + rewriter.getF64Type()); + Value b = convertScalarToDtype(rewriter, loc, adaptor.getB(), + rewriter.getF64Type()); rewriter.replaceOpWithNewOp(op, a, b); return success(); } @@ -176,15 +178,16 @@ class ConvertTorchTensorLiteralOp unsigned bitWidth = elemTy.getIntOrFloatBitWidth(); Type builtinTensorElemTy = IntegerType::get(context, bitWidth); auto shapedType = - RankedTensorType::get(type.getShape(), builtinTensorElemTy); + RankedTensorType::get(type.getShape(), builtinTensorElemTy); auto rawData = elements.getRawData(); - DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer( - shapedType, rawData); + DenseElementsAttr newAttr = + DenseElementsAttr::getFromRawBuffer(shapedType, rawData); rewriter.replaceOpWithNewOp(op, newAttr); return success(); } } - if (auto elements = op.getValueAttr().dyn_cast()) { + if (auto elements = + op.getValueAttr().dyn_cast()) { if (auto type = elements.getType().dyn_cast()) { if (auto intType = type.getElementType().dyn_cast()) { Type builtinTensorElemTy = @@ -360,7 +363,8 @@ class ConvertAtenBoolLikeOp : public OpConversionPattern { // ----------------------------------------------------------------------------- namespace { -class ConvertTorchToArith : public ConvertTorchToArithBase { +class ConvertTorchToArith + : public ConvertTorchToArithBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -439,9 +443,11 @@ class ConvertTorchToArith : public ConvertTorchToArithBase typeConverter, context); patterns.add>( typeConverter, context); - target.addIllegalOp(); + target.addIllegalOp(); patterns.add>( typeConverter, context); + patterns.add>( + typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 4eb02215a8bf..e4bf1886bb91 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -51,6 +51,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, Value zero = rewriter.create(loc, 0); Value one = rewriter.create(loc, 1); + Value negone = rewriter.create(loc, -1); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) @@ -73,40 +74,507 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, torchTypeEnd.getType().isa()) return rewriter.notifyMatchFailure(op, "unimplemented optional type arg"); - int64_t step; - if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) { - if (!op.getStep().getType().template isa()) - return op->emitError("unimplemented: step is not constant"); - step = 1; - } - + Value stepIndex = castIntToIndex(rewriter, loc, adaptor.getStep()); Value start = toPositiveValidDim(rewriter, loc, torchTypeStart, builtinTypeStart, zero, dimSize); - Value end = toPositiveValidDim(rewriter, loc, torchTypeEnd, builtinTypeEnd, - dimSize, dimSize); - // end >= start ? end : start - Value endSgeStart = rewriter.create( - loc, arith::CmpIPredicate::sge, end, start); - end = rewriter.create(loc, endSgeStart, end, start); - Value stepIndex = rewriter.create(loc, step); + // We cannot use to positive valid dim as for negative strides we need to + // clamp to `-1` so that the full tensor bounds are available: + Value end = builtinTypeEnd; + if (torchTypeEnd.getType().isa()) { + end = dimSize; + } else { + end = castIntToIndex(rewriter, loc, end); + Value endcmp = rewriter.create( + loc, arith::CmpIPredicate::slt, end, zero); + Value endadd = rewriter.create(loc, end, dimSize); + end = rewriter.create(loc, endcmp, endadd, end); + endcmp = rewriter.create(loc, arith::CmpIPredicate::slt, end, + zero); + end = rewriter.create(loc, endcmp, negone, end); + endcmp = rewriter.create(loc, arith::CmpIPredicate::sgt, end, + dimSize); + end = rewriter.create(loc, endcmp, dimSize, end); + } // Slice logic: resultSize = floordiv(end - start + step - 1, step) resultShape = getTensorSizes(rewriter, loc, input); Value len = rewriter.create(loc, end, start); + + // We check the difference between start and end to determine the total size: + Value stepcmp = rewriter.create(loc, arith::CmpIPredicate::sge, + stepIndex, zero); + Value stepsign = rewriter.create(loc, stepcmp, one, negone); Value resultSize = rewriter.create(loc, len, stepIndex); - resultSize = rewriter.create(loc, resultSize, one); + resultSize = rewriter.create(loc, resultSize, stepsign); resultSize = rewriter.create(loc, resultSize, stepIndex); + + // Clamp the size to [0, ...]: + Value szcmp = rewriter.create(loc, arith::CmpIPredicate::slt, + resultSize, zero); + resultSize = rewriter.create(loc, szcmp, zero, resultSize); resultShape[dim] = resultSize; strides.resize(inputType.getRank(), one); offsets.resize(inputType.getRank(), zero); offsets[dim] = start; - strides[dim] = rewriter.create(loc, strides[dim], stepIndex); + strides[dim] = stepIndex; return success(); } +// Example: +// input = tensor([[[0., 1., 2., 3.], +// [4., 5., 6., 7.]]]) +// torch.ops.aten.reflection_pad1d(input, (3,1)); +// padding_left = 3, +// padding_right = 1 +// output = tensor([[[3., 2., 1., 0., 1., 2., 3., 2.], +// [7., 6., 5., 4., 5., 6., 7., 6.]]]) +// Checks: 1) Each of padding_left and padding_right must be non-negative and +// less than the size of the last dimension. +// Implementation: a) Construct a result tensor of +// shape of input tensor except for the last dimension. +// The last dimension of the result tensor should be last +// dimension of input tensor + left padding size + right +// padding size. Initialize result tensor to all zeros +// b) Setup affine map to take slice from input tensor of size +// left padding starting from +// second column onwards as first column is reflection +// boundary +// c) Reflect the affine map to have resultant slice reflected +// d) Take the slice and write from begining in result tensor +// e) write the original tensor next into result tensor +// f) Setup affine map to take slice from input tensor of right +// padding size ending +// at second last column as last column is reflection +// boundary for right padding +// g) Reflect the affine map to have resultant slice reflected +// h) Take the slice and write from left padding size + orignal +// tensor last dim size +// into result tensor +// Uses the ideas/code used for AtenReflectionPad2dOp +namespace { +class ConvertAtenReflectionPad1dOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenReflectionPad1dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + SmallVector padInts; + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) + return rewriter.notifyMatchFailure( + op, "only constant int padding range is supported"); + + MLIRContext *context = rewriter.getContext(); + Location loc = op.getLoc(); + + // Lambda Unitility Functions + // Create an Integer expression of x + y + auto createIAdd = [&](Value x, Value y) { + return rewriter.create(loc, x, y); + }; + + // Create an integer expression of x - y + auto createISub = [&](Value x, Value y) { + return rewriter.create(loc, x, y); + }; + + enum PadLocation { PAD_LEFT = 0, PAD_RIGHT = 1, PAD_CENTER = 2 }; + + Value input = adaptor.getSelf(); + Type indexType = rewriter.getIndexType(); + Value zero = getConstant(rewriter, loc, 0, indexType); + Value one = getConstant(rewriter, loc, 1, indexType); + auto inputType = llvm::cast(input.getType()); + auto outputType = llvm::cast( + getTypeConverter()->convertType(op->getResult(0).getType())); + unsigned numDims = inputType.getRank(); + assert(numDims >= 2 && "Not enough input dimensions"); + int64_t lastDim = numDims - 1; + SmallVector inputShape = getTensorSizes(rewriter, loc, input); + Value lastDimSize = inputShape[lastDim]; // input [1,2,4], then lastDim = 2, + // inputShape[2] will give 4 + + Value tileWidth[3], extractOffset[3], insertOffset[3]; + + tileWidth[PAD_LEFT] = + getConstant(rewriter, loc, padInts[PAD_LEFT], indexType); + tileWidth[PAD_RIGHT] = + getConstant(rewriter, loc, padInts[PAD_RIGHT], indexType); + tileWidth[PAD_CENTER] = lastDimSize; + + extractOffset[PAD_LEFT] = one; + // The offset for the right hand padding "bar" is: + // [right] lastDimSize - (tileWidth[PAD_RIGHT] + one) + extractOffset[PAD_RIGHT] = + createISub(lastDimSize, createIAdd(tileWidth[PAD_RIGHT], one)); + extractOffset[PAD_CENTER] = zero; + + insertOffset[PAD_LEFT] = zero; + insertOffset[PAD_RIGHT] = createIAdd(lastDimSize, tileWidth[PAD_LEFT]); + insertOffset[PAD_CENTER] = tileWidth[PAD_LEFT]; + + SmallVector resultShape{inputShape}; + // Result's last dimension will have size: + // lastDimSize + left padding size + right padding size + resultShape[lastDim] = + createIAdd(resultShape[lastDim], + createIAdd(tileWidth[PAD_LEFT], tileWidth[PAD_RIGHT])); + Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape, + inputType.getElementType()); + + // Helper to reflect/reverse the i-th dimension of an affine map without + // symbols. This only works if applied on a tensor for which the + // corresponding dimension has a statically known size + auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i, + int64_t size) { + AffineExpr d = map.getResult(i); + return map.replace(d, size - d - 1, numDims, + 0); // left reflect for (3,1) on input shape (1,2,4). + // size = 3, lastDim=2, numDims=3 + }; + + SmallVector iteratorTypes{ + numDims, utils::IteratorType::parallel}; + auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context); + SmallVector allOneStrides(numDims, one); + + auto addTileToResult = [&](PadLocation padPosition) { + // Create the tile by extracting a slice from the input tensor. + SmallVector extractShape{inputShape}; + extractShape[lastDim] = tileWidth[padPosition]; + SmallVector extractOffsets(numDims, zero); + extractOffsets[lastDim] = extractOffset[padPosition]; + Value tile = rewriter.create( + loc, input, extractOffsets, extractShape, allOneStrides); + + auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context); + // Setup the affine map function to resverse the tile along the horizontal + // for left and right slices + if (padPosition < PAD_CENTER) { + inputMap = reflectDim(inputMap, numDims, lastDim, padInts[padPosition]); + // Take reflected slice as per inputMap + tile = rewriter + .create( + loc, llvm::cast(tile.getType()), tile, + tile, ArrayRef({inputMap, idMap}), iteratorTypes, + [](OpBuilder &b, Location nestedLoc, ValueRange args) { + b.create(nestedLoc, args[0]); + }) + .getResult(0); + } + // Insert the tile in the resultTensor + SmallVector insertOffsets(numDims, zero); + insertOffsets[lastDim] = insertOffset[padPosition]; + resultTensor = rewriter.create( + loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides); + }; + + if (padInts[PAD_LEFT] > 0) + addTileToResult(PAD_LEFT); + if (padInts[PAD_RIGHT] > 0) + addTileToResult(PAD_RIGHT); + addTileToResult(PAD_CENTER); + + rewriter.replaceOpWithNewOp(op, outputType, resultTensor); + return success(); + } +}; +} // namespace + +namespace { + +// Lower the aten.reflection.pad_2d operator into a sequence of +// tensor.extract_slice, linalg.generic, and tensor_insert_slice +// operations. + +// To understand the lowering, consider this pytorch example: +// +// >>> t = torch.tensor([[[1.0,2,3],[4,5,6], [7,8,9]]]) +// >>> t +// tensor([[[1., 2., 3.], +// [4., 5., 6.], +// [7., 8., 9.]]]) +// >>> torch.ops.aten.reflection_pad2d(t, [1,2,1,2]) +// tensor([[[5., 4., 5., 6., 5., 4.], +// [2., 1., 2., 3., 2., 1.], +// [5., 4., 5., 6., 5., 4.], +// [8., 7., 8., 9., 8., 7.], +// [5., 4., 5., 6., 5., 4.], +// [2., 1., 2., 3., 2., 1.]]]) +// +// The result can be subdivided into "tiles" corresponding to either +// the input tensor (in the center) or slices of the input tensor +// whose width and height is determined by the padding sizes and which +// are reflected through the side of the central input tensor that +// they touch. +// In the example above, the tiles are: +// top left: [[5]] +// top center: [[4,5,6]] +// top right: [[5,4]] +// center left [[2,1],[5,4],[8,7]] +// center: copy of the input tensor +// center right: [[2,1],[5,4],[8,7]] +// bottom left: [[5,4],[2,1]] +// center bottom: [[2,3,2]] +// center right: [[2,1]] +// +// The lowering uses a tensor.extract_slice operation to create each tile, +// a linalg.generic for the reflection, and a tensor.insert_slice to +// insert the tile in the resulting tensor. +class ConvertAtenReflectionPad2dOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenReflectionPad2dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + SmallVector padInts; + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) + return rewriter.notifyMatchFailure( + op, "only support constant int pad ranges"); + + Location loc = op.getLoc(); + // Some generic helper functions for creating arithmetic operations. + auto createAdd = [&](Value x, Value y) { + return rewriter.create(loc, x, y); + }; + + auto createAdds = [&](std::initializer_list values) { + assert(values.size() >= 2); + return std::accumulate(values.begin() + 1, values.end(), data(values)[0], + createAdd); + }; + + auto createSub = [&](Value x, Value y) { + return rewriter.create(loc, x, y); + }; + + auto createSubs = [&](std::initializer_list values) { + assert(values.size() >= 2); + return std::accumulate(values.begin() + 1, values.end(), data(values)[0], + createSub); + }; + + // Enums for specifying the coordinates of a tile. An "h" prefix + // is used to stand for "horizontal" and "v" for "vertical" + // throughout. + enum PadHLoc { LEFT = 0, RIGHT = 1, HCENTER = 2 }; + enum PadVLoc { TOP = 0, BOTTOM = 1, VCENTER = 2 }; + + // Helper functions for obtaining information about the operator's + // padding arguments. + auto getHPadArgument = [&](PadHLoc l) { + assert(l < HCENTER); + return padInts[l]; + }; + + auto getVPadArgument = [&](PadVLoc l) { + assert(l < VCENTER); + return padInts[2 + l]; + }; + + auto shouldCreateTile = [&](PadVLoc v, PadHLoc h) { + if (!(h == HCENTER || getHPadArgument(h) > 0)) + return false; + if (!(v == VCENTER || getVPadArgument(v) > 0)) + return false; + + return true; + }; + + Value input = adaptor.getSelf(); + MLIRContext *context = rewriter.getContext(); + auto inputType = llvm::cast(input.getType()); + auto outputType = llvm::cast( + getTypeConverter()->convertType(op->getResult(0).getType())); + unsigned numDims = inputType.getRank(); + + assert(numDims >= 2 && "Not enough input dimensions"); + + SmallVector inputShape = getTensorSizes(rewriter, loc, input); + int64_t hDim = numDims - 1; + int64_t vDim = numDims - 2; + Value hDimSize = inputShape[hDim]; + Value vDimSize = inputShape[vDim]; + + assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] && + "Left padding too large"); + assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] && + "Right padding too large"); + assert(getVPadArgument(TOP) < inputType.getShape()[vDim] && + "Top padding too large"); + assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] && + "Bottom padding too large"); + + Type indexType = rewriter.getIndexType(); + Value zero = getConstant(rewriter, loc, 0, indexType); + Value one = getConstant(rewriter, loc, 1, indexType); + + Value tileWidth[3]; + tileWidth[HCENTER] = hDimSize; + for (auto h : {LEFT, RIGHT}) + tileWidth[h] = getConstant(rewriter, loc, getHPadArgument(h), indexType); + + Value tileHeight[3]; + tileHeight[VCENTER] = vDimSize; + for (auto v : {TOP, BOTTOM}) + tileHeight[v] = getConstant(rewriter, loc, getVPadArgument(v), indexType); + + // Helper to reflect/reverse the i-th dimension of an affine map + // without symbols. This only works if applied on a tensor + // for which the corresponding dimension has a statically + // known size which is good enough since we only apply + // it to reflect the padding slices. + auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i, + int64_t size) { + AffineExpr d = map.getResult(i); + return map.replace(d, size - d - 1, numDims, 0); + }; + + // Create output shape and tensor + SmallVector resultShape{inputShape}; + resultShape[vDim] = + createAdds({resultShape[vDim], tileHeight[TOP], tileHeight[BOTTOM]}); + resultShape[hDim] = + createAdds({resultShape[hDim], tileWidth[LEFT], tileWidth[RIGHT]}); + + Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape, + inputType.getElementType()); + + // Construction of the tiles + + // Example: central left tile + // + // Let m the width of the left padding as returned by getHPadargument(LEFT) + // and n the size of the input tensor's "horizontal" dimension, i.e. + // hDimSize. Assume that the subtensor of the input tensor in the relevant + // (i.e. last two) dimensions is: + // + // x_1,1 x_1,2 ... x_1,m + // x_2,1 x_2,2 ... x_2,m + // . + // . + // . + // x_n,1 x_n,2 ... x_n,m + // + // The padding tile consists of the columns 2, ..., m + 1 + // of the input in reverse order. The first column gets + // skipped because this is the column through which the + // reflection happens. + // + // x_1,m x_1,m-1 ... x_1,2 + // x_2,m x_1,m-1 ... x_2,2 + // . + // . + // . + // x_n,m x_n,m-1 ... x_n,2 + // + // The tile will be inserted to the left of the copy of the input tensor + // in the output tensor, i.e. with horizontal offset 0. + // The top padding determines the vertical offset. + + // Tiles on the diagonal (e.g. (TOP, LEFT)) are reflected through + // two sides, i.e. their columns and rows must be reversed. + + // Setup information about the tiles + + // Compute the offsets for extracting the slice from the + // input. We need to skip the row or column through which + // the tile should be reflected, if any (none for the center tile). + Value extractHOffset[3]; + extractHOffset[LEFT] = one; + extractHOffset[HCENTER] = zero; + extractHOffset[RIGHT] = createSubs({hDimSize, tileWidth[RIGHT], one}); + + Value extractVOffset[3]; + extractVOffset[TOP] = one; + extractVOffset[VCENTER] = zero; + extractVOffset[BOTTOM] = createSubs({vDimSize, tileHeight[BOTTOM], one}); + + // Compute the horizontal and vertical offsets for inserting + // the tiles in the resultTensor. + Value insertHOffset[3]; + insertHOffset[LEFT] = zero; + insertHOffset[HCENTER] = tileWidth[LEFT]; + insertHOffset[RIGHT] = createAdd(hDimSize, tileWidth[LEFT]); + + Value insertVOffset[3]; + insertVOffset[TOP] = zero; + insertVOffset[VCENTER] = tileHeight[TOP]; + insertVOffset[BOTTOM] = createAdd(vDimSize, tileHeight[TOP]); + + auto shouldHReflect = [](PadHLoc l) { return l == LEFT || l == RIGHT; }; + auto shouldVReflect = [](PadVLoc l) { return l == TOP || l == BOTTOM; }; + + SmallVector iteratorTypes{ + numDims, utils::IteratorType::parallel}; + auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context); + SmallVector allOneStrides(numDims, one); + + auto createTile = [&](PadVLoc verticalPos, PadHLoc horizontalPos) { + // Create the tile by extracting a slice from the input tenor. + SmallVector extractShape{inputShape}; + extractShape[hDim] = tileWidth[horizontalPos]; + extractShape[vDim] = tileHeight[verticalPos]; + + SmallVector extractOffsets(numDims, zero); + extractOffsets[hDim] = extractHOffset[horizontalPos]; + extractOffsets[vDim] = extractVOffset[verticalPos]; + + Value tile = rewriter.create( + loc, input, extractOffsets, extractShape, allOneStrides); + + // Reverse the tile along the horizontal, vertical, or both + // dimensions. + auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context); + if (shouldHReflect(horizontalPos)) { + inputMap = + reflectDim(inputMap, numDims, hDim, getHPadArgument(horizontalPos)); + } + if (shouldVReflect(verticalPos)) { + inputMap = + reflectDim(inputMap, numDims, vDim, getVPadArgument(verticalPos)); + } + + tile = rewriter + .create( + loc, llvm::cast(tile.getType()), tile, + tile, ArrayRef({inputMap, idMap}), iteratorTypes, + [](OpBuilder &b, Location nestedLoc, ValueRange args) { + b.create(nestedLoc, args[0]); + }) + .getResult(0); + + // Insert the tile in the resultTensor. + SmallVector insertOffsets(numDims, zero); + insertOffsets[hDim] = insertHOffset[horizontalPos]; + insertOffsets[vDim] = insertVOffset[verticalPos]; + + resultTensor = rewriter.create( + loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides); + }; + + for (auto v : {TOP, BOTTOM, VCENTER}) + for (auto h : {LEFT, RIGHT, HCENTER}) + if (shouldCreateTile(v, h)) + createTile(v, h); + + rewriter.replaceOpWithNewOp(op, outputType, resultTensor); + + return success(); + } +}; +} // namespace + namespace { class ConvertAtenFlattenUsingIntsOp : public OpConversionPattern { @@ -170,6 +638,68 @@ class ConvertAtenFlattenUsingIntsOp }; } // namespace +// Lower aten.unflatten.int into tensor.expand_shape +namespace { +class ConvertAtenUnflattenIntOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenUnflattenIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + BaseTensorType outputTensorType = op.getType().cast(); + if (!outputTensorType.hasSizes()) + return rewriter.notifyMatchFailure( + op, "unimplemented: output must have known sizes"); + + std::optional maybeRank = getTensorRank(self); + if (!maybeRank) + return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor"); + auto inputTensorType = self.getType().cast(); + if (!inputTensorType || !inputTensorType.hasSizes()) { + return rewriter.notifyMatchFailure(op, + "Expected input type having sizes"); + } + int inputRank = inputTensorType.getSizes().size(); + int outputRank = outputTensorType.getSizes().size(); + + int64_t dimInt; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt))) + return rewriter.notifyMatchFailure( + op, "unimplemented: requires dim to be constants"); + + dimInt = toPositiveDim(dimInt, inputRank); + if (!isValidDim(dimInt, inputRank)) + return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); + + auto sizesOp = op.getSizes().getDefiningOp(); + int numSizes = sizesOp.getNumOperands(); + + SmallVector reassociations(inputRank); + if (inputRank > 0) { + for (int i = 0; i < dimInt; ++i) + reassociations[i].push_back(i); + + for (int i = 0; i < numSizes; ++i) + reassociations[dimInt].push_back(i + dimInt); + + for (int i = dimInt + numSizes; i < outputRank; ++i) + reassociations[i - numSizes + 1].push_back(i); + } + + auto expandTy = getTypeConverter()->convertType(outputTensorType); + auto expand = rewriter + .create( + loc, expandTy, adaptor.getSelf(), reassociations) + .getResult(); + rewriter.replaceOp(op, expand); + return success(); + } +}; +} // namespace + namespace { /// The `ConvertAtenViewOp` conversion pattern converts `aten.View` op to /// one `linalg.TensorExpandShape` op for all expanded dimensions and one @@ -327,14 +857,23 @@ class ConvertAtenViewOp : public OpConversionPattern { auto resultType = typeConverter->convertType(op.getType()).cast(); int64_t resultRank = resultType.getRank(); - if (resultRank == 0) - return rewriter.notifyMatchFailure(op, - "result shape of rank 0 is invalid"); + if (resultRank == 0) { + rewriter + .replaceOpWithNewOp( + op, resultType, input, ArrayRef()) + .getResult(); + return success(); + } - // TODO: add support for case inputRank 0 expanded to size 1 - if (inputRank == 0) - return rewriter.notifyMatchFailure( - op, "unimplemented: input rank 0 is not supported"); + if (inputRank == 0) { + llvm::SmallVector outshape(resultRank, 1); + auto expandTy = + RankedTensorType::get(outshape, resultType.getElementType()); + Value expand = rewriter.create( + op.getLoc(), expandTy, input, ArrayRef()); + rewriter.replaceOpWithNewOp(op, resultType, expand); + return success(); + } // Extract the desired output size as a list of integers. This list should // have been created using the operation `torch.prim.ListConstruct`. @@ -553,6 +1092,7 @@ class ConvertAtenViewOp : public OpConversionPattern { return success(); } + // TODO: audit possibility of sparsity on these tensors Type adjustedResultType = RankedTensorType::get( makeShapeLLVMCompatible(outputShape), resultType.getElementType()); Type adjustedInputType = RankedTensorType::get( @@ -580,6 +1120,7 @@ class ConvertAtenViewOp : public OpConversionPattern { intermediateShape.push_back(sum); } + // TODO: audit possibility of sparsity on these tensor Type intermediateResultType = RankedTensorType::get(makeShapeLLVMCompatible(intermediateShape), resultType.getElementType()); @@ -1030,13 +1571,6 @@ class ConvertAtenCatOp : public OpConversionPattern { RankedTensorType newResultType = typeConverter->convertType(op.getType()).cast(); - - auto outElemType = newResultType.getElementType(); - for (size_t i = 0; i < tensors.size(); ++i) { - tensors[i] = torch_to_linalg::convertTensorToElementType( - rewriter, loc, tensors[i], outElemType); - } - int rank = newResultType.getRank(); Value dimValue = op.getDim(); int64_t dim; @@ -1046,48 +1580,25 @@ class ConvertAtenCatOp : public OpConversionPattern { if (!isValidDim(dim, rank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - SmallVector offsets, sizes, strides; - sizes.reserve(rank); - strides.resize(rank, rewriter.create(loc, 1)); - offsets.resize(rank, rewriter.create(loc, 0)); - - for (int i = 0; i < rank; ++i) - sizes.push_back(rewriter.createOrFold(loc, tensors[0], i)); - - // Calculate the size of the `dim` result dimension by adding the dim size - // of each tensor together. - Value resultDimSize = sizes[dim]; - - Value dimIndex = rewriter.createOrFold( - loc, rewriter.getIndexAttr(dim)); - for (auto tensor : ArrayRef(tensors).drop_front()) { - auto size = rewriter.createOrFold(loc, tensor, dimIndex); - resultDimSize = - rewriter.createOrFold(loc, resultDimSize, size); + auto outElemType = newResultType.getElementType(); + for (size_t i = 0; i < tensors.size(); ++i) { + auto inputType = cast(tensors[i].getType()); + if (inputType.getElementType() != outElemType) { + tensors[i] = torch_to_linalg::convertTensorToElementType( + rewriter, loc, tensors[i], outElemType); + } } - sizes[dim] = resultDimSize; - - auto toOpFoldResult = [](Value v) -> OpFoldResult { - auto op = v.getDefiningOp(); - if (!op) - return v; - return op.getValue(); - }; - Value result = rewriter.create( - loc, getAsOpFoldResult(sizes), newResultType.getElementType()); + llvm::SmallVector filteredTensors; for (auto tensor : tensors) { - SmallVector sizes = getTensorSizes(rewriter, loc, tensor); - result = rewriter.createOrFold( - loc, tensor, result, - llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)), - llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)), - llvm::to_vector(llvm::map_range(strides, toOpFoldResult))); - offsets[dim] = - rewriter.createOrFold(loc, offsets[dim], sizes[dim]); + auto inputType = cast(tensor.getType()); + if (inputType.getDimSize(dim) != 0) { + filteredTensors.push_back(tensor); + } } - rewriter.replaceOpWithNewOp(op, newResultType, result); + rewriter.replaceOpWithNewOp(op, newResultType, dim, + filteredTensors); return success(); } }; @@ -1269,6 +1780,7 @@ class ConvertAtenSliceScatterOp auto srcType = src.getType().cast(); int64_t srcRank = srcType.getRank(); SmallVector srcAbstractSizes(srcRank, kUnknownSize); + // TODO: audit possibility of sparsity on these tensor auto abstractSrcType = RankedTensorType::get( makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType()); Value abstractSrc = @@ -1446,13 +1958,155 @@ class ConvertAtenViewAsRealOp : public OpConversionPattern { }; } // namespace +namespace { +class ConvertAtenDiagonalOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenDiagonalOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + int64_t offset; + if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset))) + return rewriter.notifyMatchFailure(op, "offset must be constant"); + int64_t dim1; + if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) + return rewriter.notifyMatchFailure(op, "dim1 must be constant"); + int64_t dim2; + if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2))) + return rewriter.notifyMatchFailure(op, "dim2 must be constant"); + + Value inputMatrix = adaptor.getSelf(); + RankedTensorType inputType = inputMatrix.getType().cast(); + int64_t inputRank = inputType.getRank(); + + if (inputRank < 2) + return rewriter.notifyMatchFailure( + op, "input must have at least two dimensions"); + int64_t outputRank = inputRank - 1; + + dim1 = toPositiveDim(dim1, inputRank); + if (!isValidDim(dim1, inputRank)) + return rewriter.notifyMatchFailure(op, "dim1 out of range"); + dim2 = toPositiveDim(dim2, inputRank); + if (!isValidDim(dim2, inputRank)) + return rewriter.notifyMatchFailure(op, "dim2 out of range"); + if (dim1 == dim2) + return rewriter.notifyMatchFailure( + op, "diagonal dimensions cannot be identical"); + + Type elementType = inputType.getElementType(); + RankedTensorType outputType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + Location loc = op.getLoc(); + + Value dim1Size, dim2Size; + dim1Size = getDimOp(rewriter, loc, inputMatrix, dim1); + dim2Size = getDimOp(rewriter, loc, inputMatrix, dim2); + + // compute the length of the diagonal with possible offset + // if the offset is very large or very small, diagSize=0 and an empty tensor + // is returned + Value indexZero = rewriter.create(loc, 0); + Value indexMinusOne = rewriter.create(loc, -1); + Value indexOffset = rewriter.create(loc, offset); + Value offsetIsNegative = rewriter.create( + loc, arith::CmpIPredicate::sle, indexOffset, indexZero); + Value sizeForNegativeOffset = rewriter.create( + loc, + rewriter.create( + loc, rewriter.create(loc, dim1Size, indexOffset), + dim2Size), + indexZero); + Value sizeForPositiveOffset = rewriter.create( + loc, + rewriter.create( + loc, rewriter.create(loc, dim2Size, indexOffset), + dim1Size), + indexZero); + Value diagSize = rewriter.create( + loc, offsetIsNegative, sizeForNegativeOffset, sizeForPositiveOffset); + + // depending on its sign, the offset affects only the row or column indices + // of the diagonal + Value diagStart1 = rewriter.create( + loc, offsetIsNegative, + rewriter.create(loc, indexOffset, indexMinusOne), + indexZero); + Value diagStart2 = rewriter.create(loc, offsetIsNegative, + indexZero, indexOffset); + + SmallVector outputDims; + for (auto i = 0; i < inputRank; i++) { + if (!(i == dim1 || i == dim2)) + outputDims.push_back(getDimOp(rewriter, loc, inputMatrix, i)); + } + outputDims.push_back(diagSize); + + Value outputMatrix = rewriter.create( + loc, getAsOpFoldResult(outputDims), elementType); + + SmallVector indexingMaps = { + AffineMap::getMultiDimIdentityMap(outputRank, rewriter.getContext())}; + SmallVector iteratorTypes( + outputRank, utils::IteratorType::parallel); + + auto diagonal = + rewriter + .create( + loc, outputMatrix.getType(), ValueRange{}, outputMatrix, + indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + SmallVector diagIndices; + Value indexOnDiag = + b.create(loc, outputRank - 1); + Value dim1Index = + b.create(loc, indexOnDiag, diagStart1); + Value dim2Index = + b.create(loc, indexOnDiag, diagStart2); + + // specify at which input indices the diagonal values are + // extracted + for (int indIn = 0, indOut = 0; indIn < inputRank; indIn++) { + if (indIn == dim1) + diagIndices.push_back(dim1Index); + else if (indIn == dim2) + diagIndices.push_back(dim2Index); + else { + diagIndices.push_back( + b.create(loc, indOut)); + indOut++; + } + } + Value diagElt = b.create( + loc, elementType, inputMatrix, diagIndices); + b.create(loc, diagElt); + }) + .getResult(0); + + rewriter.replaceOpWithNewOp(op, outputType, diagonal); + return success(); + } +}; +} // namespace + void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); + target.addIllegalOp(); + patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); + patterns.add(typeConverter, context); + target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); @@ -1480,4 +2134,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index 277341bea874..b8754a306711 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -79,7 +79,8 @@ class ConvertAtenGatherOp : public OpConversionPattern { int64_t dim; if (!matchPattern(dimValue, m_TorchConstantInt(&dim))) return op.emitError("unimplemented: dim is not constant"); - int64_t inputRank = adaptor.getSelf().getType().cast().getRank(); + int64_t inputRank = + adaptor.getSelf().getType().cast().getRank(); dim = toPositiveDim(dim, inputRank); if (!isValidDim(dim, inputRank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); @@ -206,8 +207,8 @@ namespace { // // TODO: Find an optimal lowering. // current lowering is not optimal for bags of large embeddings. -// Since it traverses the output tensor multiple times. -// +// Since it traverses the output tensor multiple times. +// // class ConvertAtenEmbeddingBagPaddingIdxOp @@ -248,9 +249,9 @@ class ConvertAtenEmbeddingBagPaddingIdxOp } if (modeInt != torch_upstream::EmbeddingBagMode::MODE_SUM) { - return rewriter.notifyMatchFailure( - op, - "Unimplemented: Mean and Max mode are not supported yet for EmbeddingBag."); + return rewriter.notifyMatchFailure(op, + "Unimplemented: Mean and Max mode are " + "not supported yet for EmbeddingBag."); } bool isSparse; @@ -291,28 +292,28 @@ class ConvertAtenEmbeddingBagPaddingIdxOp SmallVector indicesExpr; indicesExpr.push_back(mlir::getAffineDimExpr(1, context)); auto indicesIndexingMap = - AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, - indicesExpr, context); + AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, + indicesExpr, context); SmallVector offsetsExpr; offsetsExpr.push_back(mlir::getAffineDimExpr(0, context)); auto offsetIndexingMap = - AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, - offsetsExpr, context); + AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, + offsetsExpr, context); SmallVector outputExpr; outputExpr.push_back(mlir::getAffineDimExpr(0, context)); outputExpr.push_back(mlir::getAffineDimExpr(2, context)); auto outputIndexingMap = - AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, - outputExpr, context); + AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, + outputExpr, context); SmallVector indexingMaps = { - indicesIndexingMap, - offsetIndexingMap, - outputIndexingMap, + indicesIndexingMap, + offsetIndexingMap, + outputIndexingMap, }; // Reduce along the indices dim @@ -326,15 +327,15 @@ class ConvertAtenEmbeddingBagPaddingIdxOp Value indicesLength; if (!discardLastOffset) { SmallVector sizes{getDimOp(rewriter, loc, offsets, 0), - embeddingDim}; + embeddingDim}; initTensor = createZeroInitTensor(rewriter, loc, sizes, weightElemTy); offsetsLength = getDimOp(rewriter, loc, offsets, 0); indicesLength = getDimOp(rewriter, loc, indices, 0); } else { return rewriter.notifyMatchFailure( - op, "Unimplemented: include last offset is not yet " - "supported for EmbeddingBag."); + op, "Unimplemented: include last offset is not yet " + "supported for EmbeddingBag."); } Value embeddingBagResult = @@ -351,10 +352,10 @@ class ConvertAtenEmbeddingBagPaddingIdxOp Value indexI = b.create(loc, /*value=*/0); Value indexIToInt = castIndexToInt64(b, loc, indexI); - Value one = getConstant( - b, loc, 1, - mlir::IntegerType::get(getContext(), 64, - IntegerType::Signless)); + Value one = + getConstant(b, loc, 1, + mlir::IntegerType::get( + getContext(), 64, IntegerType::Signless)); Value offsetIndexPlusOneInt = b.create(loc, indexIToInt, one); @@ -378,7 +379,7 @@ class ConvertAtenEmbeddingBagPaddingIdxOp loc, arith::CmpIPredicate::eq, offsetsI, indicesIndex); Value offsetLessThanOrEqualToIndicesIndex = b.create(loc, offsetLessThanIndicesIndex, - offsetEqualToIndicesIndex); + offsetEqualToIndicesIndex); Value indicesIndexLessThanNextOffset = b.create(loc, arith::CmpIPredicate::slt, @@ -393,19 +394,18 @@ class ConvertAtenEmbeddingBagPaddingIdxOp castIntToIndex(b, loc, indexInIndices)); indexIntoWeight.push_back( b.create(loc, /*value=*/2)); - Value weightElem = b.create( - loc, weight, indexIntoWeight); - - Value addResult = b.create(loc, weightElem, - initTensorElem); - Value select = - b.create(loc, indicesIndexWithinBounds, - addResult, initTensorElem); + Value weightElem = + b.create(loc, weight, indexIntoWeight); + + Value addResult = + b.create(loc, weightElem, initTensorElem); + Value select = b.create( + loc, indicesIndexWithinBounds, addResult, initTensorElem); b.create(loc, select); - }) - .getResult(0); + }) + .getResult(0); - // cast outputType. + // cast outputType. auto restulType0 = typeConverter->convertType(op->getResult(0).getType()); Value castedEmbeddingBagResult = rewriter.create(loc, restulType0, embeddingBagResult); @@ -439,7 +439,7 @@ class ConvertAtenEmbeddingBagPaddingIdxOp rewriter.create(loc, resultType3, indicesOut); rewriter.replaceOp(op, {castedEmbeddingBagResult, castedOffsetResult, - castedBagSizeResult, castedMaxIndices}); + castedBagSizeResult, castedMaxIndices}); return success(); } @@ -498,7 +498,8 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern { resultExpr.push_back(rewriter.getAffineDimExpr(i)); } - auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr}); + auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr}, + rewriter.getContext()); Value finalRes = rewriter @@ -552,7 +553,8 @@ static Value makeIndexValuePositive(OpBuilder &b, Location loc, Value index, // e.g. x: [2, 3] // x[[4], [6, 1]] -> x[6, 4] namespace { -class ConvertAtenIndexTensorHackedTwinOp : public OpConversionPattern { +class ConvertAtenIndexTensorHackedTwinOp + : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index b263786c3dbb..44ac95ce0429 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -29,6 +29,34 @@ using namespace mlir::torch; using namespace mlir::torch::Torch; namespace { + +static void getZeroPoint(Value value, Value &zeropoint) { + if (auto make = value.getDefiningOp()) { + zeropoint = make.getZeroPoint(); + } +} + +static Value transposeValue(Location loc, Value value, ArrayRef perms, + PatternRewriter &rewriter) { + auto valueTy = value.getType().cast(); + auto inShape = valueTy.getShape(); + llvm::SmallVector outShape; + llvm::SmallVector dynDims; + for (size_t i = 0; i < perms.size(); ++i) { + outShape.push_back(inShape[perms[i]]); + if (ShapedType::isDynamic(inShape[perms[i]])) { + dynDims.push_back(rewriter.create(loc, value, perms[i])); + } + } + + auto outTy = RankedTensorType::get(outShape, valueTy.getElementType()); + Value empty = rewriter.create(loc, outTy, dynDims); + Value transpose = + rewriter.create(loc, value, empty, perms) + ->getResult(0); + return transpose; +} + class ConvertAtenMmOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -64,11 +92,27 @@ class ConvertAtenMmOp : public OpConversionPattern { op.getSelf().getType().cast(); ValueTensorType rhsTorchType = op.getMat2().getType().cast(); + + Value lhsZeroPoint, rhsZeroPoint; + getZeroPoint(op.getSelf(), lhsZeroPoint); + getZeroPoint(op.getMat2(), rhsZeroPoint); + + if (static_cast(lhsZeroPoint) != static_cast(lhsZeroPoint)) { + return rewriter.notifyMatchFailure( + op, "unsupported: aten.mm with mixed quantization"); + } + if (lhsTorchType.getDtype() != rhsTorchType.getDtype()) { return rewriter.notifyMatchFailure( op, "unsupported: aten.mm with different input element types"); } + bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType); + if (lhsZeroPoint && isUnsigned) { + return rewriter.notifyMatchFailure( + op, "unsupported: unsigned quantized matmul not supported"); + } + Value lhsDim0 = rewriter.create(loc, lhs, 0); Value rhsDim1 = rewriter.create(loc, rhs, 1); @@ -89,8 +133,26 @@ class ConvertAtenMmOp : public OpConversionPattern { rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType); Value matmul; - auto intType = dyn_cast(lhsTorchType.getDtype()); - if (intType && intType.isUnsigned()) { + if (lhsZeroPoint && !isUnsigned) { + lhsZeroPoint = typeConverter->materializeTargetConversion( + rewriter, loc, + getTypeConverter()->convertType(lhsZeroPoint.getType()), + lhsZeroPoint); + rhsZeroPoint = typeConverter->materializeTargetConversion( + rewriter, loc, + getTypeConverter()->convertType(rhsZeroPoint.getType()), + rhsZeroPoint); + lhsZeroPoint = rewriter.create( + loc, rewriter.getI32Type(), lhsZeroPoint); + rhsZeroPoint = rewriter.create( + loc, rewriter.getI32Type(), rhsZeroPoint); + matmul = + rewriter + .create( + loc, zeroFill.getType(), + ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint}, zeroFill) + .getResult(0); + } else if (isUnsigned) { matmul = rewriter .create( loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill) @@ -124,7 +186,8 @@ class ConvertAtenFlipOp : public OpConversionPattern { Location loc = op->getLoc(); MLIRContext *context = op.getContext(); Value self = adaptor.getSelf(); - auto selfRank = adaptor.getSelf().getType().cast().getRank(); + auto selfRank = + adaptor.getSelf().getType().cast().getRank(); Type elementType = adaptor.getSelf().getType().cast().getElementType(); Value c1 = @@ -191,8 +254,9 @@ class ConvertAtenMatmulOp : public OpConversionPattern { Value lhs = adaptor.getSelf(); Value rhs = adaptor.getOther(); - if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) { return failure(); + } auto lhsType = lhs.getType().cast(); auto rhsType = rhs.getType().cast(); @@ -260,7 +324,26 @@ class ConvertAtenMatmulOp : public OpConversionPattern { return success(); } - // Fourth Case: Batch-Matrix Multiplication. + // Fourth Case: Vec-Vec Multiplication. + if (lhsRank == 2 && rhsRank == 2) { + Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0); + Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1); + Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0); + Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1); + checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0); + + Value zeroTensor = createZeroInitTensor( + rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType); + Value matmul = + rewriter + .create(loc, zeroTensor.getType(), + ValueRange{lhs, rhs}, zeroTensor) + .getResult(0); + rewriter.replaceOpWithNewOp(op, newResultType, matmul); + return success(); + } + + // Fifth Case: Batch-Matrix Multiplication. // TODO: Handle batch matrix multiplication when one of the matrix is unity // rank and the other has batch dimension. if (lhsRank > 1 && rhsRank > 1) { @@ -316,8 +399,8 @@ class ConvertAtenMatmulOp : public OpConversionPattern { // TODO: Improve usage of static shape information. SmallVector lhsTargetShape(lhsBroadcastToShape.size(), ShapedType::kDynamic); - auto lhsBroadcastType = - RankedTensorType::get(lhsTargetShape, lhsType.getElementType()); + auto lhsBroadcastType = RankedTensorType::get( + lhsTargetShape, lhsType.getElementType(), lhsType.getEncoding()); if (failed(torch_to_linalg::broadcastToGivenShape( op, rewriter, lhs, lhsBroadcastToShape, lhsBroadcastType, broadcastedLhs))) { @@ -326,8 +409,8 @@ class ConvertAtenMatmulOp : public OpConversionPattern { } SmallVector rhsTargetShape(rhsBroadcastToShape.size(), ShapedType::kDynamic); - auto rhsBroadcastType = - RankedTensorType::get(rhsTargetShape, rhsType.getElementType()); + auto rhsBroadcastType = RankedTensorType::get( + rhsTargetShape, rhsType.getElementType(), rhsType.getEncoding()); if (failed(torch_to_linalg::broadcastToGivenShape( op, rewriter, rhs, rhsBroadcastToShape, rhsBroadcastType, broadcastedRhs))) { @@ -429,8 +512,8 @@ class ConvertAtenMatmulOp : public OpConversionPattern { resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1}); Value zeroTensor = createZeroInitTensor(rewriter, loc, resultShape, elementType); - auto indexingMaps = - AffineMap::inferFromExprList({lhsExpr, rhsExpr, outExpr}); + auto indexingMaps = AffineMap::inferFromExprList( + {lhsExpr, rhsExpr, outExpr}, rewriter.getContext()); iteratorTypes.insert(iteratorTypes.end(), {utils::IteratorType::parallel, utils::IteratorType::reduction, @@ -474,7 +557,8 @@ class ConvertAtenBmmOp : public OpConversionPattern { RankedTensorType lhsType = lhs.getType().cast(); RankedTensorType rhsType = rhs.getType().cast(); Type newResultType = getTypeConverter()->convertType(op.getType()); - Type resultElementType = newResultType.cast().getElementType(); + Type resultElementType = + newResultType.cast().getElementType(); Type lhsElementType = lhsType.cast().getElementType(); Type rhsElementType = rhsType.cast().getElementType(); @@ -486,13 +570,15 @@ class ConvertAtenBmmOp : public OpConversionPattern { // Convert the inputs element type equivalent to the result' element type. if (lhsElementType != rhsElementType) { if (lhsElementType != resultElementType) { - // True if the lhs element type is not equal to the result' element type. - lhs = torch_to_linalg::convertTensorToElementType( - rewriter, loc, lhs, resultElementType); + // True if the lhs element type is not equal to the result' element + // type. + lhs = torch_to_linalg::convertTensorToElementType(rewriter, loc, lhs, + resultElementType); } else { - // True if the rhs element type is not equal to the result' element type. - rhs = torch_to_linalg::convertTensorToElementType( - rewriter, loc, rhs, resultElementType); + // True if the rhs element type is not equal to the result' element + // type. + rhs = torch_to_linalg::convertTensorToElementType(rewriter, loc, rhs, + resultElementType); } } @@ -510,7 +596,8 @@ class ConvertAtenBmmOp : public OpConversionPattern { checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1); Value initTensor0 = createZeroInitTensor( - rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, resultElementType); + rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, + resultElementType); Value bmm = rewriter @@ -534,21 +621,64 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { MLIRContext *context = op->getContext(); Value input = adaptor.getInput(); /* in form of N*C*H*W */ Value weight = adaptor.getWeight(); /* in form of F*C*H*W */ + Value bias = adaptor.getBias(); + auto resultTy = op.getType().cast(); + + Value inputZp, weightZp; + if (auto make = op.getInput() + .getDefiningOp()) { + input = make.getSelf(); + inputZp = make.getZeroPoint(); + input = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(input.getType()), input); + inputZp = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(inputZp.getType()), + inputZp); + } + + if (auto make = op.getWeight() + .getDefiningOp()) { + weight = make.getSelf(); + weightZp = make.getZeroPoint(); + + weight = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(weight.getType()), weight); + weightZp = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(weightZp.getType()), + weightZp); + } + + if (static_cast(inputZp) != static_cast(weightZp)) { + return rewriter.notifyMatchFailure( + op, "lhs and rhs of convolution must either be both int or fp"); + } + + if (inputZp && weightZp && !isa(bias.getType())) { + auto biasDTy = bias.getType().cast().getElementType(); + if (!biasDTy.isInteger(32)) { + return rewriter.notifyMatchFailure( + op, "quantized result ty should be i32 accumulator"); + } + } bool transposed = true; if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) return rewriter.notifyMatchFailure( op, "unimplemented: only constant transposed supported"); - Type elementType = - input.getType().cast().getElementType(); - if (!elementType.isa()) - return op.emitError("unimplemented: non-floating point type"); + auto inputDTy = input.getType().cast().getElementType(); + auto weightDTy = weight.getType().cast().getElementType(); + auto resultDTy = resultTy.toBuiltinTensor().getElementType(); + + if (!inputDTy.isa() || + !weightDTy.isa() || + !resultDTy.isa()) + return op.emitError("unimplemented: non-fp not-int type"); size_t inRank = input.getType().cast().getRank(); - size_t numSpacialDims = inRank - 2; - if (numSpacialDims != 2) + size_t numSpatialDims = inRank - 2; + if (numSpatialDims < 1 || numSpatialDims > 3) return rewriter.notifyMatchFailure( - op, "unimplemented: only 2D convolution currently supported"); + op, "unimplemented: only 1d-3d convolution currently supported"); Type intType = IntegerType::get(context, 64); auto castIndexToInt = [&](Value v) { @@ -573,7 +703,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "only support constant int strides"); SmallVector dilationInts; - if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts))) + if (!matchPattern(op.getDilation(), + m_TorchListOfConstantInts(dilationInts))) return rewriter.notifyMatchFailure(op, "only support constant int dilations"); @@ -617,6 +748,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { SmallVector outDims{inBatch, weightBatch}; Value paddedInput; if (transposed) { + if (!inputDTy.isa() || + !weightDTy.isa() || + !resultDTy.isa()) + return rewriter.notifyMatchFailure( + op, "transpose does not support non-fp type yet"); + Value c0 = rewriter.create(loc, rewriter.getIndexAttr(0)); Value c1 = @@ -629,7 +766,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { std::iter_swap(weightInitDims.begin(), weightInitDims.begin() + 1); outDims[1] = weightInitDims[0]; Value weightInitTensor = - createZeroInitTensor(rewriter, loc, weightInitDims, elementType); + createZeroInitTensor(rewriter, loc, weightInitDims, weightDTy); SmallVector iteratorTypes( inRank, utils::IteratorType::parallel); SmallVector indexingMaps{ @@ -662,7 +799,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { SmallVector outerSizes{inBatch, inChannels}; SmallVector innerSizes{inBatch, inChannels}; SmallVector offsets{c0, c0}; - for (size_t i = 0; i < numSpacialDims; i++) { + for (size_t i = 0; i < numSpatialDims; i++) { Value innerSize = rewriter.create(loc, inDims[i], c1); innerSize = rewriter.create( loc, innerSize, castIntToIndex(rewriter, loc, strideIntValues[i])); @@ -686,7 +823,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { // Allocate padded input tensor Value initTensor = - createZeroInitTensor(rewriter, loc, outerSizes, elementType); + createZeroInitTensor(rewriter, loc, outerSizes, inputDTy); // Insert input into allocated tensor SmallVector strideIndexValues{c1, c1}; @@ -699,7 +836,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { initTensor, offsets, insertSizes, strideIndexValues); // Calculate output dims - for (size_t i = 0; i < numSpacialDims; i++) + for (size_t i = 0; i < numSpatialDims; i++) outDims.push_back(torch_to_linalg::getOutputDimForConvTransposeOps( rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i], castIndexToInt(weightDims[i]), strideIntValues[i], @@ -707,36 +844,57 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { // Set stride to 1 strideInts.clear(); - strideInts.append(numSpacialDims, 1); - + strideInts.append(numSpatialDims, 1); } else { + Value pad = inputZp; + if (!pad) { + if (isa(inputDTy)) + pad = rewriter.create( + op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0)); + if (isa(inputDTy)) + pad = rewriter.create( + op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0)); + } + + if (pad.getType() != inputDTy) { + if (isa(inputDTy)) + pad = rewriter.create(op.getLoc(), inputDTy, pad); + + if (isa(inputDTy)) + pad = rewriter.create(op.getLoc(), inputDTy, pad); + } + // Pad input paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor( - op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2); + op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2, pad); // Calculate output dims - for (size_t i = 0; i < numSpacialDims; i++) + for (size_t i = 0; i < numSpatialDims; i++) outDims.push_back(torch_to_linalg::getOutputDimForConvOps( rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i], castIndexToInt(weightDims[i]), strideIntValues[i])); } Value initTensor = rewriter.create( - loc, getAsOpFoldResult(outDims), elementType); + loc, getAsOpFoldResult(outDims), resultDTy); - Value bias = adaptor.getBias(); Value outputTensor; if (bias.getType().isa()) { - Value c0float = rewriter.create( - loc, FloatAttr::get(elementType, 0.0)); - outputTensor = rewriter.create(loc, c0float, initTensor) - .getResult(0); + Value c0; + if (resultDTy.isa()) { + c0 = rewriter.create(loc, + FloatAttr::get(resultDTy, 0.0)); + } else if (resultDTy.isa()) { + c0 = rewriter.create(loc, + IntegerAttr::get(resultDTy, 0)); + } + outputTensor = + rewriter.create(loc, c0, initTensor).getResult(0); + } else { auto biasType = bias.getType().cast(); if (biasType.getRank() != 1) return rewriter.notifyMatchFailure(op, "expect bias to be rank 1"); - if (elementType != biasType.getElementType()) - return rewriter.notifyMatchFailure(op, "unimplemented: type promotion"); auto resultRank = initTensor.getType().cast().getRank(); SmallVector indexingMaps = { @@ -776,114 +934,208 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { weightSliceSizes.append(weightDims); Value conv; - if (groupSize == 1) { - // TODO: add 1D and 3D case - conv = - rewriter - .create( - loc, outputTensor.getType(), ValueRange{paddedInput, weight}, - outputTensor, stridesAttr, dilationAttr) - .getResult(0); - } else { - // Special depthwise case - auto inShape = makeShapeTorchCompatible( - input.getType().cast().getShape()); - auto weightShape = makeShapeTorchCompatible( - weight.getType().cast().getShape()); - if (weightShape[0] != kUnknownSize && inShape[1] == groupSize && - weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) { - // Collapse weight shape - SmallVector collapsedDims = {{0, 1}, {2}, {3}}; - SmallVector collapsedShape{ - (weightShape[0] == kUnknownSize ? kUnknownSize - : weightShape[0] * weightShape[1]), - weightShape[2], weightShape[3]}; - Type collapsedType = RankedTensorType::get( - makeShapeLLVMCompatible(collapsedShape), elementType); - Value collapsedWeight = rewriter.create( - loc, collapsedType, weight, collapsedDims); - + // the code so far is able to respect all numSpatialDims + // the code below this point is numSpatialDims specific and groupSize + // specific + // TODO: factor out the above code into a helper function, and then separate + // convolution into: + // - grouped 1d-3d + // - grouped 1d-3d (quantized) + // - ungrouped 1d-3d + if (groupSize == 1 && !inputZp && !weightZp) { + switch (numSpatialDims) { + case 1: + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, weight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); + break; + case 2: + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, weight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); + break; + case 3: conv = rewriter - .create( + .create( loc, outputTensor.getType(), - ValueRange{paddedInput, collapsedWeight}, outputTensor, + ValueRange{paddedInput, weight}, outputTensor, stridesAttr, dilationAttr) .getResult(0); + break; + default: + return rewriter.notifyMatchFailure( + op, "unimplemented: only 1D, 2D, and 3D convolution supported"); + }; + Type newResultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, conv); + return success(); + } - Type newResultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, newResultType, conv); - return success(); + if (groupSize == 1 && inputZp && weightZp) { + // The quantized version uses a different channel ordering so we need to + // permute the tensors in order to use the existing path. We should + // eventually directly support this channel ordering. + llvm::SmallVector inPerms, weightPerms; + inPerms.push_back(0); // N stays at the front for input. + // Then we expect the spatial dimensions + for (size_t i = 0; i < numSpatialDims; ++i) { + inPerms.push_back(i + 2); + weightPerms.push_back(i + 2); } + inPerms.push_back(1); + weightPerms.append({1, 0}); - // Grouped case, use the grouped conv linalg op - auto expandGroups = [&](Value tensor, size_t dim) { - auto inType = tensor.getType().cast(); - auto inShape = makeShapeTorchCompatible(inType.getShape()); - - SmallVector outShape; - for (auto i = 0; i < (long)inShape.size(); i++) { - if (i == 1) { - outShape.push_back(groupSize); - } - if (i == (long)dim) { - outShape.push_back(inShape[i] == kUnknownSize - ? kUnknownSize - : inShape[i] / groupSize); - } else { - outShape.push_back(inShape[i]); - } - } - - SmallVector indices; - for (auto i = 0; i <= (long)inShape.size(); i++) { - if (i == (long)dim) { - indices.push_back({i, ++i}); - continue; - } - indices.push_back({i}); - } + paddedInput = transposeValue(op.getLoc(), paddedInput, inPerms, rewriter); + weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter); + outputTensor = + transposeValue(op.getLoc(), outputTensor, inPerms, rewriter); - auto retType = inType.clone(makeShapeLLVMCompatible(outShape)); - return rewriter.create(loc, retType, tensor, - indices); + switch (numSpatialDims) { + case 2: + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, weight, inputZp, weightZp}, + outputTensor, stridesAttr, dilationAttr) + .getResult(0); + break; + case 3: + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, weight, inputZp, weightZp}, + outputTensor, stridesAttr, dilationAttr) + .getResult(0); + break; + default: + return rewriter.notifyMatchFailure( + op, "unimplemented: only 1D, 2D, and 3D convolution supported"); }; - // expand F,C,H,W -> G,F/G,C,H,W - auto expandWeight = [&](Value tensor) { - auto inType = tensor.getType().cast(); - auto inShape = makeShapeTorchCompatible(inType.getShape()); - - SmallVector outShape{ - groupSize, (inShape[0] == kUnknownSize ? kUnknownSize - : inShape[0] / groupSize)}; - outShape.append(inShape.begin() + 1, inShape.end()); + llvm::SmallVector outPerms; + outPerms.push_back(0); + outPerms.push_back(inPerms.size() - 1); + for (size_t i = 0; i < numSpatialDims; ++i) { + outPerms.push_back(i + 1); + } + conv = transposeValue(op.getLoc(), conv, outPerms, rewriter); - SmallVector indices{{0, 1}}; - for (auto i = 2; i <= (long)inShape.size(); i++) - indices.push_back({i}); + Type newResultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, conv); + return success(); + } - auto retType = inType.clone(makeShapeLLVMCompatible(outShape)); - return rewriter.create(loc, retType, tensor, - indices); - }; + if (inputZp || weightZp) + return rewriter.notifyMatchFailure( + op, "unimplemented: quantized grouped convolutions"); - Value paddedInputExpanded = expandGroups(paddedInput, 1); - Value weightExpanded = expandWeight(weight); - auto expandOutputTensor = expandGroups(outputTensor, 1); + if (numSpatialDims != 2) + return rewriter.notifyMatchFailure( + op, "unimplemented: only 2D grouped convolution supported"); + + // Special depthwise case + auto inShape = makeShapeTorchCompatible( + input.getType().cast().getShape()); + auto weightShape = makeShapeTorchCompatible( + weight.getType().cast().getShape()); + if (weightShape[0] != kUnknownSize && inShape[1] == groupSize && + weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) { + // Collapse weight shape + SmallVector collapsedDims = {{0, 1}, {2}, {3}}; + SmallVector collapsedShape{ + (weightShape[0] == kUnknownSize ? kUnknownSize + : weightShape[0] * weightShape[1]), + weightShape[2], weightShape[3]}; + Type collapsedType = RankedTensorType::get( + makeShapeLLVMCompatible(collapsedShape), weightDTy); + Value collapsedWeight = rewriter.create( + loc, collapsedType, weight, collapsedDims); - // TODO: add 1D and 3D case conv = rewriter - .create( - loc, expandOutputTensor.getResultType(), - ValueRange{paddedInputExpanded, weightExpanded}, - expandOutputTensor.getResult(), stridesAttr, dilationAttr) + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, collapsedWeight}, outputTensor, + stridesAttr, dilationAttr) .getResult(0); - conv = rewriter.create( - loc, outputTensor.getType(), conv, - expandOutputTensor.getReassociationIndices()); + Type newResultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, conv); + return success(); } + // Grouped case, use the grouped conv linalg op + auto expandGroups = [&](Value tensor, size_t dim) { + auto inType = tensor.getType().cast(); + auto inShape = makeShapeTorchCompatible(inType.getShape()); + + SmallVector outShape; + for (auto i = 0; i < (long)inShape.size(); i++) { + if (i == 1) { + outShape.push_back(groupSize); + } + if (i == (long)dim) { + outShape.push_back(inShape[i] == kUnknownSize + ? kUnknownSize + : inShape[i] / groupSize); + } else { + outShape.push_back(inShape[i]); + } + } + + SmallVector indices; + for (auto i = 0; i <= (long)inShape.size(); i++) { + if (i == (long)dim) { + indices.push_back({i, ++i}); + continue; + } + indices.push_back({i}); + } + + auto retType = inType.clone(makeShapeLLVMCompatible(outShape)); + return rewriter.create(loc, retType, tensor, + indices); + }; + + // expand F,C,H,W -> G,F/G,C,H,W + auto expandWeight = [&](Value tensor) { + auto inType = tensor.getType().cast(); + auto inShape = makeShapeTorchCompatible(inType.getShape()); + + SmallVector outShape{ + groupSize, + (inShape[0] == kUnknownSize ? kUnknownSize : inShape[0] / groupSize)}; + outShape.append(inShape.begin() + 1, inShape.end()); + + SmallVector indices{{0, 1}}; + for (auto i = 2; i <= (long)inShape.size(); i++) + indices.push_back({i}); + + auto retType = inType.clone(makeShapeLLVMCompatible(outShape)); + return rewriter.create(loc, retType, tensor, + indices); + }; + + Value paddedInputExpanded = expandGroups(paddedInput, 1); + Value weightExpanded = expandWeight(weight); + auto expandOutputTensor = expandGroups(outputTensor, 1); + + // TODO: add 1D and 3D case + conv = rewriter + .create( + loc, expandOutputTensor.getResultType(), + ValueRange{paddedInputExpanded, weightExpanded}, + expandOutputTensor.getResult(), stridesAttr, dilationAttr) + .getResult(0); + + conv = rewriter.create( + loc, outputTensor.getType(), conv, + expandOutputTensor.getReassociationIndices()); Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 87419f0935ab..e795d2ea9fb8 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -72,36 +72,17 @@ checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter, return success(); } -// Creates a pooling operation based on the type specified by `OpTy` and -// arguments passed. -template -static LogicalResult createPoolingOp( - Operation *op, ConversionPatternRewriter &rewriter, Value self, - bool supportNonFPInput, bool ceilMode, int64_t dimensionality, - SmallVectorImpl &kernelSizeIntValues, - SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, - SmallVectorImpl &dilationInts, Attribute initValueAttr, - SmallVectorImpl &outTensorShape, Value &paddedInput, Value &result) { - Location loc = op->getLoc(); +static Value +computeOutputTensor(Operation *op, ConversionPatternRewriter &rewriter, + Value self, int64_t dimensionality, bool ceilMode, + SmallVectorImpl &strideInts, + SmallVectorImpl &paddingInts, + SmallVectorImpl &dilationInts, + SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &outTensorShape, Value initValue) { Type elementType = self.getType().cast().getElementType(); - if (!elementType.isa() && !supportNonFPInput) - return op->emitError("unimplemented: non-floating point type"); - - SmallVector lowPaddingIncludingNC = {0, 0}; - lowPaddingIncludingNC.append(paddingInts); - SmallVector highPaddingIncludingNC = lowPaddingIncludingNC; - - if (ceilMode) { - for (int64_t i = 0; i < dimensionality; ++i) { - highPaddingIncludingNC[i + 2] += strideInts[i]; - } - } + Location loc = op->getLoc(); - Value initValue = rewriter.create(loc, cast(initValueAttr)); - paddedInput = torch_to_linalg::getPaddedTensor( - op, rewriter, self, lowPaddingIncludingNC, highPaddingIncludingNC, - initValue); - Value N = getDimOp(rewriter, loc, self, 0); Value C = getDimOp(rewriter, loc, self, 1); @@ -123,8 +104,54 @@ static LogicalResult createPoolingOp( // Create output tensor initialized with smallest floating point value. outTensorShape.insert(outTensorShape.begin(), {N, C}); - Value outTensorInitialized = - createInitTensor(rewriter, loc, outTensorShape, elementType, initValue); + return createInitTensor(rewriter, loc, outTensorShape, elementType, + initValue); +} + +static Value padInputTensor(Operation *op, ConversionPatternRewriter &rewriter, + Value self, bool ceilMode, int64_t dimensionality, + SmallVectorImpl &strideInts, + SmallVectorImpl &paddingInts, + Value initValue) { + SmallVector lowPaddingIncludingNC = {0, 0}; + lowPaddingIncludingNC.append(paddingInts); + SmallVector highPaddingIncludingNC = lowPaddingIncludingNC; + + if (ceilMode) { + for (int64_t i = 0; i < dimensionality; ++i) { + highPaddingIncludingNC[i + 2] += strideInts[i]; + } + } + + return torch_to_linalg::getPaddedTensor(op, rewriter, self, + lowPaddingIncludingNC, + highPaddingIncludingNC, initValue); +} + +// Creates a pooling operation based on the type specified by `OpTy` and +// arguments passed. +template +static LogicalResult createPoolingOp( + Operation *op, ConversionPatternRewriter &rewriter, Value self, + bool supportNonFPInput, bool ceilMode, int64_t dimensionality, + SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, + SmallVectorImpl &dilationInts, Attribute initValueAttr, + SmallVectorImpl &outTensorShape, Value &paddedInput, Value &result) { + Location loc = op->getLoc(); + Type elementType = self.getType().cast().getElementType(); + if (!elementType.isa() && !supportNonFPInput) + return op->emitError("unimplemented: non-floating point type"); + + Value initValue = + rewriter.create(loc, cast(initValueAttr)); + + paddedInput = padInputTensor(op, rewriter, self, ceilMode, dimensionality, + strideInts, paddingInts, initValue); + + auto outTensorInitialized = computeOutputTensor( + op, rewriter, self, dimensionality, ceilMode, strideInts, paddingInts, + dilationInts, kernelSizeIntValues, outTensorShape, initValue); auto stridesAttr = rewriter.getI64VectorAttr(strideInts); auto dilationAttr = rewriter.getI64VectorAttr(dilationInts); @@ -137,57 +164,179 @@ static LogicalResult createPoolingOp( ValueRange{paddedInput, windowTensor}, outTensorInitialized, stridesAttr, dilationAttr) .getResult(0); - return success(); } - namespace { -class ConvertAtenMaxPool2dOp : public OpConversionPattern { + +template struct DimensionTraits {}; + +template <> struct DimensionTraits { + static constexpr int64_t Dim = 2; + // unused const variable warning suppression: + static_assert(Dim == Dim); +}; + +template <> struct DimensionTraits { + static constexpr int64_t Dim = 3; + // unused const variable warning suppression: + static_assert(Dim == Dim); +}; + +template +class ConvertAtenMaxPoolOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +private: + static const int64_t Dim = DimensionTraits::Dim; + + LogicalResult createPoolingMax3D(AtenMaxPool3dOp &op, + typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter, + SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &strideInts, + SmallVectorImpl &paddingInts, + SmallVectorImpl &dilationInts, + bool ceilMode) const { + SmallVector outTensorShape; + Value self = adaptor.getSelf(); + Type elementType = self.getType().cast().getElementType(); + TypedAttr smallestFPValueAttr = rewriter.getFloatAttr( + elementType, + APFloat::getInf(elementType.cast().getFloatSemantics(), + /*Negative=*/true)); + Value initValue = + rewriter.create(op->getLoc(), smallestFPValueAttr); + + Value paddedInput = padInputTensor(op, rewriter, self, ceilMode, 3, + strideInts, paddingInts, initValue); + + auto outTensorInitialized = computeOutputTensor( + op, rewriter, self, 3, ceilMode, strideInts, paddingInts, dilationInts, + kernelSizeIntValues, outTensorShape, initValue); + + auto shape = + castIntVectorToIndexVector(rewriter, op->getLoc(), kernelSizeIntValues); + Value windowTensor = rewriter.create( + op->getLoc(), getAsOpFoldResult(shape), elementType); + + MLIRContext *context = rewriter.getContext(); + + auto mapInput = mlir::AffineMap::get( + 8, 0, + { + rewriter.getAffineDimExpr(0), // n + rewriter.getAffineDimExpr(1), // c + // dim_d * stride_d + kernal_d * dilation_d + rewriter.getAffineDimExpr(2) * + getAffineConstantExpr(strideInts[0], context) + + rewriter.getAffineDimExpr(5) * + getAffineConstantExpr(dilationInts[0], context), + // dim_h * stride_h + kernal_h * dilation_h + rewriter.getAffineDimExpr(3) * + getAffineConstantExpr(strideInts[1], context) + + rewriter.getAffineDimExpr(6) * + getAffineConstantExpr(dilationInts[1], context), + // dim_w * stride_w + kernal_w * dilation_w + rewriter.getAffineDimExpr(4) * + getAffineConstantExpr(strideInts[2], context) + + rewriter.getAffineDimExpr(7) * + getAffineConstantExpr(dilationInts[2], context), + }, + context); + auto mapKernel = + mlir::AffineMap::get(8, 0, + { + rewriter.getAffineDimExpr(5), // kd + rewriter.getAffineDimExpr(6), // kh + rewriter.getAffineDimExpr(7) // kw + }, + context); + auto mapOutput = mlir::AffineMap::get( + 8, 0, + {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1), + rewriter.getAffineDimExpr(2), rewriter.getAffineDimExpr(3), + rewriter.getAffineDimExpr(4)}, + context); + auto iteratorTypes = + SmallVector(5, utils::IteratorType::parallel); + iteratorTypes.append(3, utils::IteratorType::reduction); + SmallVector indexingMaps = {mapInput, mapKernel, mapOutput}; + Value poolingOp = + rewriter + .create( + op->getLoc(), + /* result types */ outTensorInitialized.getType(), + /* operands */ ValueRange({paddedInput, windowTensor}), + /* outputs */ outTensorInitialized, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value currentVal = args[0], accMaxValue = args[2]; + Value max_result = + b.create(loc, currentVal, accMaxValue); + ; + b.create(loc, max_result); + }) + .getResult(0); + Type newResultType = this->getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, poolingOp); + return success(); + } + public: - using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(AtenMaxPool2dOp op, OpAdaptor adaptor, + matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); - const TypeConverter *typeConverter = getTypeConverter(); + const TypeConverter *typeConverter = this->getTypeConverter(); Value self = adaptor.getSelf(); int64_t selfRank = self.getType().cast().getRank(); - // TODO: Add support for 3D inputs. - if (selfRank == 3) + + if (selfRank != Dim + 2) return rewriter.notifyMatchFailure( - op, "unimplemented: only support 4D input"); + op, "unimplemented: Does not support inputs with rank"); bool ceilMode; - SmallVector kernelSizeIntValues; - SmallVector strideInts, paddingInts, dilationInts; - if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts))) + SmallVector kernelSizeIntValues; + SmallVector strideInts, paddingInts, dilationInts; + if (!matchPattern(op.getDilation(), + m_TorchListOfConstantInts(dilationInts))) return rewriter.notifyMatchFailure(op, "only support constant int dilations"); - if (failed(checkAndGetPoolingParameters( - op, rewriter, typeConverter, ceilMode, kernelSizeIntValues, - strideInts, paddingInts))) + + if (failed(checkAndGetPoolingParameters(op, rewriter, typeConverter, + ceilMode, kernelSizeIntValues, + strideInts, paddingInts))) return rewriter.notifyMatchFailure(op, "invalid pooling parameters"); Type elementType = self.getType().cast().getElementType(); - TypedAttr smallestFPValueAttr = rewriter.getFloatAttr( - elementType, - APFloat::getInf(elementType.cast().getFloatSemantics(), - /*Negative=*/true)); - SmallVector outTensorShape; - // `maxpool2d` contains the result of maxpool2d operation over the input. - Value maxPool2d, paddedInput; - if (failed(createPoolingOp( - op, rewriter, self, /*supportNonFPInput=*/false, ceilMode, - /*dimensionality=*/2, kernelSizeIntValues, strideInts, paddingInts, - dilationInts, smallestFPValueAttr, outTensorShape, paddedInput, - maxPool2d))) - return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d"); - Type newResultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, newResultType, maxPool2d); - return success(); + + if constexpr (Dim == 2) { + SmallVector outTensorShape; + // `maxpool2d` contains the result of maxpool2d operation over the input. + Value maxPool2d, paddedInput; + TypedAttr smallestFPValueAttr = rewriter.getFloatAttr( + elementType, + APFloat::getInf( + elementType.cast().getFloatSemantics(), + /*Negative=*/true)); + if (failed(createPoolingOp( + op, rewriter, self, /*supportNonFPInput=*/true, ceilMode, + /*dimensionality=*/2, kernelSizeIntValues, strideInts, + paddingInts, dilationInts, smallestFPValueAttr, outTensorShape, + paddedInput, maxPool2d))) + return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d"); + Type newResultType = this->getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, maxPool2d); + return success(); + } else { + return createPoolingMax3D(op, adaptor, rewriter, kernelSizeIntValues, + strideInts, paddingInts, dilationInts, + ceilMode); + } } }; } // namespace @@ -241,7 +390,8 @@ class ConvertAtenMaxPool2dWithIndicesOp bool ceilMode; SmallVector kernelSizeIntValues; SmallVector strideInts, paddingInts, dilationInts; - if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts))) + if (!matchPattern(op.getDilation(), + m_TorchListOfConstantInts(dilationInts))) return rewriter.notifyMatchFailure(op, "only support constant int dilations"); if (failed(checkAndGetPoolingParameters( @@ -292,8 +442,8 @@ class ConvertAtenMaxPool2dWithIndicesOp // Here we have six dimensions, each corresponding to N, C, Hout, Wout, kH, // and kW, respectively, as described in the algorithm above. - SmallVector indexingMaps = - AffineMap::inferFromExprList({inputExprs, kernelExprs, outputExprs}); + SmallVector indexingMaps = AffineMap::inferFromExprList( + {inputExprs, kernelExprs, outputExprs}, rewriter.getContext()); SmallVector iteratorTypes( 4, utils::IteratorType::parallel); iteratorTypes.push_back(utils::IteratorType::reduction); @@ -372,7 +522,6 @@ class ConvertAtenMaxPool2dWithIndicesOp }; } // namespace - namespace { template class ConvertAtenAvgPoolOp : public OpConversionPattern { @@ -383,7 +532,7 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); - + Location loc = op->getLoc(); const TypeConverter *typeConverter = this->getTypeConverter(); Value self = adaptor.getSelf(); @@ -397,9 +546,9 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { bool ceilMode; SmallVector kernelSizeIntValues; SmallVector strideInts, paddingInts, dilationInts(Dim, 1); - if (failed(checkAndGetPoolingParameters( - op, rewriter, typeConverter, ceilMode, kernelSizeIntValues, - strideInts, paddingInts))) + if (failed(checkAndGetPoolingParameters(op, rewriter, typeConverter, + ceilMode, kernelSizeIntValues, + strideInts, paddingInts))) return rewriter.notifyMatchFailure(op, "invalid pooling parameters"); // TODO: Add support for count_include_pad equal to `False`. @@ -408,27 +557,31 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { m_TorchConstantBool(&countIncludePad))) return rewriter.notifyMatchFailure( op, "count_include_pad must be a constant"); - if (!countIncludePad) { + + // If the padding is zero then there is no padding to include. + if (!countIncludePad && + !llvm::all_of(paddingInts, [](int64_t p) { return p == 0; })) { return rewriter.notifyMatchFailure( op, "unimplemented: count_include_pad is expected to be true"); } // `sumPool` contains the result of sumpool operation over the input. Value sumPool, paddedInput; - SmallVector outTensorShape; + SmallVector outTensorShape; if (failed(createPoolingOp( op, rewriter, self, /*supportNonFPInput=*/true, ceilMode, - /*dimensionality=*/Dim, kernelSizeIntValues, strideInts, paddingInts, - dilationInts, rewriter.getZeroAttr(inputElementType), outTensorShape, - paddedInput, sumPool))) + /*dimensionality=*/Dim, kernelSizeIntValues, strideInts, + paddingInts, dilationInts, rewriter.getZeroAttr(inputElementType), + outTensorShape, paddedInput, sumPool))) return rewriter.notifyMatchFailure(op, "unable to compute sumpool"); Value divisor; if constexpr (std::is_same()) { Value kHtimeskW = rewriter.create( loc, kernelSizeIntValues[0], kernelSizeIntValues[1]); - divisor = op.getDivisorOverride().getType().template isa() - ? kHtimeskW - : adaptor.getDivisorOverride(); + divisor = + op.getDivisorOverride().getType().template isa() + ? kHtimeskW + : adaptor.getDivisorOverride(); } else { divisor = kernelSizeIntValues[0]; } @@ -436,9 +589,10 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { Value outputTensor = rewriter.create( loc, getAsOpFoldResult(outTensorShape), resultElementType); - SmallVector indexingMapsAvg(2, rewriter.getMultiDimIdentityMap(Dim+2)); + SmallVector indexingMapsAvg( + 2, rewriter.getMultiDimIdentityMap(Dim + 2)); SmallVector iteratorTypesAvg( - Dim+2, utils::IteratorType::parallel); + Dim + 2, utils::IteratorType::parallel); Value avgPool = rewriter .create( @@ -459,20 +613,411 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { return success(); } }; -} +} // namespace + +/* +This section is for lowering adaptive pooling ops, which cannot generally be +decomposed into typical pooling ops. Given an input tensor of rank (N,C,Hin) and +an output spatial size Hout, an element of the output tensor at position (n, c, +h) is computed as follows. + 1. compute st(h) = (h*Hin)//Hout + 2. compute en(h) = 1 + ((h+1)*Hin - 1)//Hout + 3. apply the operation (max or avg) over input[n, c, st(h):en(h)] +This is problematic for linalg ops for a few reasons: + 1. The access to the input tensor is not constantly strided + 2. The size of the window itself is not contant: en(h) - st(h) can vary with +h! Although it is a bit like using a hammer to paint, our workaround is to use +tensor.extract to access the elements of the input tensor inside our linalg +generic op's payload. + +Current TODO's: + 1. gather most of the boilerplate out of this op and make it into an +adaptive pooling helper function. + 2. figure out what to do with the conflicting decompositions in +DecomposeComplexOps.cpp + 3. Implement more efficient passes for when the kernel-size, input spatial +dims, and output spatial dims are constant. +*/ + +namespace { +class ConvertAtenAdaptiveAvgPool1dOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenAdaptiveAvgPool1dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Location loc = op->getLoc(); + const TypeConverter *typeConverter = getTypeConverter(); + + // get rank of input (same as rank of output) + int64_t rank = + adaptor.getSelf().getType().cast().getRank(); + // input operand should be NCH (i.e. rank 3) + if (rank != 3) { + return rewriter.notifyMatchFailure(op, "only supports input type NCH"); + } + + // input tensor and output shape + Value input = adaptor.getSelf(); + Value outputShape = op.getOutputSize(); + SmallVector outShapeVector; + getListConstructElements(outputShape, outShapeVector); + outShapeVector = + getTypeConvertedValues(rewriter, loc, typeConverter, outShapeVector); + Value hIn = getDimOp(rewriter, loc, input, 2); + Value hOut = outShapeVector[0]; + Value hOutIndex = castIntToIndex(rewriter, loc, hOut); + RankedTensorType inputType = input.getType().cast(); + RankedTensorType outputType = + typeConverter->convertType(op.getResult().getType()) + .cast(); + + // get elementType of input tensor + Type elementType = inputType.getElementType(); + + // make an iteration space of size kMax = 1 + ceildiv (hIn - 1) , hOut + Type boolType = rewriter.getI1Type(); + Value kIter; + Value constantOne = + rewriter.create(loc, rewriter.getIndexAttr(1)); + Value hInPlusOne = rewriter.create(loc, hIn, constantOne); + Value kMaxMinusOne = + rewriter.create(loc, hInPlusOne, hOutIndex); + Value kMax = rewriter.create(loc, constantOne, kMaxMinusOne); + kIter = rewriter.create( + loc, getAsOpFoldResult(ValueRange({kMax})), boolType); + + // need to buffer input, else there will possibly be an out of bounds access + // later buffVal = 0 for avg pooling and -inf for max pooling + Value buffVal = rewriter.create( + loc, elementType, rewriter.getFloatAttr(elementType, 0)); + SmallVector lowPadding = {0, 0, 0}; + SmallVector highPadding = {0, 0, 1}; + Value buffInput = torch_to_linalg::getPaddedTensor( + op, rewriter, input, lowPadding, highPadding, buffVal); + + // make a list of outputSizes + SmallVector outputSizes; + for (unsigned i = 0; i < rank - 1; i++) { + outputSizes.push_back(getDimOp(rewriter, loc, input, i)); + } + outputSizes.push_back(hOutIndex); + + // initialize a kernel size tensor (only for avg pooling) + Value kSizeTensor = rewriter.create( + loc, getAsOpFoldResult(ValueRange({hOutIndex})), elementType); + + // initialize an output tensor + Value initOutput = + createInitTensor(rewriter, loc, outputSizes, elementType, buffVal); + + // setup indexing maps and iterator types for linalg generic op + // for kIter (d0,d1,d2,d3) -> (d3) + // for output (d0,d1,d2,d3) -> (d0,d1,d2) + // for kSizeTensor (d0,d1,d2,d3) -> (d2) + SmallVector kIterExprs, outputExprs, kSizeTensorExprs; + for (unsigned i = 0; i < 3; i++) { + outputExprs.push_back(rewriter.getAffineDimExpr(i)); + } + kSizeTensorExprs.push_back(rewriter.getAffineDimExpr(2)); + kIterExprs.push_back(rewriter.getAffineDimExpr(3)); + SmallVector indexingMaps = AffineMap::inferFromExprList( + {kIterExprs, outputExprs, kSizeTensorExprs}, rewriter.getContext()); + SmallVector iteratorTypes( + 3, utils::IteratorType::parallel); + iteratorTypes.push_back(utils::IteratorType::reduction); + + Value indexOne = rewriter.create(loc, 1); + auto sumPool = rewriter.create( + loc, /*resultTensorTypes=*/ + TypeRange({initOutput.getType(), kSizeTensor.getType()}), + /*inputs=*/ValueRange({kIter}), + /*outputs=*/ValueRange({initOutput, kSizeTensor}), + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value res = args[1]; + Value ind0 = b.create(loc, 0); + Value ind1 = b.create(loc, 1); + Value ind2 = b.create(loc, 2); + Value ind3 = b.create(loc, 3); + // compute start and end indices + // st = s1( s0(ind2 * Hin) // Hout ) + Value s0 = b.create(loc, ind2, hIn); + Value s1 = b.create(loc, s0, hOutIndex); + // en = e4( 1 + e3( e2( e1( e0(ind2 + 1) * hIn ) - 1 ) // hOut ) ) + Value e0 = b.create(loc, ind2, indexOne); + Value e1 = b.create(loc, e0, hIn); + Value e2 = b.create(loc, e1, indexOne); + Value e3 = b.create(loc, e2, hOutIndex); + Value e4 = b.create(loc, indexOne, e3); + // get input element @ st + ind3: + Value wIndex = b.create(loc, s1, ind3); + Value inElt = b.create( + loc, elementType, buffInput, ValueRange({ind0, ind1, wIndex})); + // check if we extracted at windex < end index + Value cond = + b.create(loc, arith::CmpIPredicate(6), wIndex, e4); + // if inElt is in bounds, include it in the computation + // else, use buffVal = 0 (for max pool use -infinity) + Value out1 = b.create(loc, cond, inElt, buffVal); + // compute Kernel size: we store this to kwTensor + Value kSize = b.create(loc, e4, s1); + Value kSizeInt = castIndexToInt64(b, loc, kSize); + Value kSizeF = b.create(loc, elementType, kSizeInt); + // accumulate out2 to res = args[1] + Value out2 = b.create(loc, res, out1); + b.create(loc, ValueRange({out2, kSizeF})); + }); + + // make a linalg generic to divide each element by the corresponding + // Kernel Width. This step is only necessary for avg pooling. + SmallVector indexingMaps1 = AffineMap::inferFromExprList( + {kSizeTensorExprs, outputExprs}, rewriter.getContext()); + SmallVector iteratorTypes1( + 3, utils::IteratorType::parallel); + auto output = rewriter.create( + loc, /*resultTensorTypes=*/initOutput.getType(), + /*inputs=*/sumPool.getResultTensors()[1], + /*outputs=*/sumPool.getResultTensors()[0], + /*indexingMaps=*/indexingMaps1, + /*iteratorTypes=*/iteratorTypes1, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value q = b.create(loc, args[1], args[0]); + b.create(loc, q); + }); + + rewriter.replaceOpWithNewOp(op, outputType, + output.getResultTensors()); + return success(); + } +}; +} // namespace + +// The logic for this conversion is similar to the AdaptiveAvgPool1dOp +// conversion. Before writing any more adaptive pooling conversions, the logic +// in this should be off-loaded to a helper function, since each of the adaptive +// ops are essentially the same with some minor tweaks. Instead of kSizeTensor, +// we named the additional output of the linalg generic op auxTensor. +// For max pooling, auxTensor holds the indices of max values, and for +// avg pooling, the auxTensor will be kSizeTensor, used to later divide the +// sum pool by the kernel size. +namespace { +class ConvertAtenAdaptiveMaxPool2dOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenAdaptiveMaxPool2dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Location loc = op->getLoc(); + const TypeConverter *typeConverter = getTypeConverter(); + + // get rank of input (same as rank of output) + int64_t rank = + adaptor.getSelf().getType().cast().getRank(); + // input operand should be NCHW (i.e. rank 4) + if (rank != 4) { + return rewriter.notifyMatchFailure(op, "only supports input type NCHW"); + } + + // input tensor and output shape + Value input = adaptor.getSelf(); + Value outputShape = op.getOutputSize(); + SmallVector outShapeVector; + getListConstructElements(outputShape, outShapeVector); + outShapeVector = + getTypeConvertedValues(rewriter, loc, typeConverter, outShapeVector); + SmallVector inputSpatialSizes; + for (unsigned i = 2; i < rank; i++) { + inputSpatialSizes.push_back(getDimOp(rewriter, loc, input, i)); + } + SmallVector outShapeIndexVector; + for (auto v : outShapeVector) { + outShapeIndexVector.push_back(castIntToIndex(rewriter, loc, v)); + } + RankedTensorType inputType = input.getType().cast(); + RankedTensorType outputType = + typeConverter->convertType(op.getResult0().getType()) + .cast(); + + // get elementType of input tensor + Type elementType = inputType.getElementType(); + + // make an iteration space of size kMax = 1 + ceildiv (hIn - 1) , hOut + Type boolType = rewriter.getI1Type(); + SmallVector kIterSizeVector; + Value constantOne = + rewriter.create(loc, rewriter.getIndexAttr(1)); + for (int i = 0; i < rank - 2; i++) { + Value hInPlusOne = rewriter.create( + loc, inputSpatialSizes[i], constantOne); + Value kMaxMinusOne = rewriter.create( + loc, hInPlusOne, outShapeIndexVector[i]); + Value kMax = + rewriter.create(loc, constantOne, kMaxMinusOne); + kIterSizeVector.push_back(kMax); + } + Value kIter = rewriter.create( + loc, getAsOpFoldResult(kIterSizeVector), boolType); + + // need to buffer input, else there will possibly be an out of bounds access + // later buffVal = 0 for avg pooling and -inf for max pooling + auto smallestFPValueAttr = rewriter.getFloatAttr( + elementType, + APFloat::getInf(elementType.cast().getFloatSemantics(), + /*Negative=*/true)); + Value buffVal = rewriter.create(loc, elementType, + smallestFPValueAttr); + SmallVector lowPadding(rank, 0); + SmallVector highPadding(2, 0); + for (int i = 0; i < rank - 2; i++) { + highPadding.push_back(1); + } + Value buffInput = torch_to_linalg::getPaddedTensor( + op, rewriter, input, lowPadding, highPadding, buffVal); + // make a list of outputSizes + SmallVector outputSizes; + for (unsigned i = 0; i < 2; i++) { + outputSizes.push_back(getDimOp(rewriter, loc, input, i)); + } + for (unsigned i = 2; i < rank; i++) { + outputSizes.push_back(outShapeIndexVector[i - 2]); + } + + // for avg pooling the auxTensor should hold kernel widths (kSizeTensor) + // for max Pooling, it should hold the indices + RankedTensorType outputType1 = + typeConverter->convertType(op.getResult1().getType()) + .cast(); + Type indicesType = outputType1.getElementType(); + Value auxTensor = rewriter.create( + loc, getAsOpFoldResult(outputSizes), indicesType); + + // initialize an output tensor + Value initOutput = + createInitTensor(rewriter, loc, outputSizes, elementType, buffVal); + + // setup indexing maps and iterator types for linalg generic op (outputShape + // (rank),kIter (rank -2)) for kIter (d0,d1,d2,d3,d4,d5) -> (d4,d5) for + // output (d0,d1,d2,d3,d4,d5) -> (d0,d1,d2,d3) for auxTensor + // (d0,d1,d2,d3,d4,d5) -> (d0,d1,d2,d3) (or (d2,d3) for avg pooling) + SmallVector kIterExprs, outputExprs, auxTensorExprs; + // batch + channel + output spatial dims + for (unsigned i = 0; i < rank; i++) { + outputExprs.push_back(rewriter.getAffineDimExpr(i)); + auxTensorExprs.push_back(rewriter.getAffineDimExpr(i)); + } + // kIter covers last rank-2 indices + for (unsigned i = rank; i < 2 * rank - 2; i++) { + kIterExprs.push_back(rewriter.getAffineDimExpr(i)); + } + SmallVector indexingMaps = AffineMap::inferFromExprList( + {kIterExprs, outputExprs, auxTensorExprs}, rewriter.getContext()); + SmallVector iteratorTypes( + rank, utils::IteratorType::parallel); + for (unsigned i = 0; i < rank - 2; i++) { + iteratorTypes.push_back(utils::IteratorType::reduction); + } + Value indexOne = rewriter.create(loc, 1); + auto maxPool = rewriter.create( + loc, /*resultTensorTypes=*/ + TypeRange({initOutput.getType(), auxTensor.getType()}), + /*inputs=*/ValueRange({kIter}), + /*outputs=*/ValueRange({initOutput, auxTensor}), + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value res = args[1]; + Value maxIndex = args[2]; + SmallVector ind; + for (unsigned i = 0; i < 2 * rank - 2; i++) { + ind.push_back(b.create(loc, i)); + } + // compute start and end indices + // st = s1( s0(ind2 * Hin) // Hout ) + SmallVector starts; + SmallVector ends; + for (unsigned i = 2; i < rank; i++) { + Value s0 = + b.create(loc, ind[i], inputSpatialSizes[i - 2]); + Value s1 = b.create( + loc, s0, outShapeIndexVector[i - 2]); + starts.push_back(s1); + // en = e4( 1 + e3( e2( e1( e0(ind2 + 1) * hIn ) - 1 ) // hOut ) ) + Value e0 = b.create(loc, ind[i], indexOne); + Value e1 = + b.create(loc, e0, inputSpatialSizes[i - 2]); + Value e2 = b.create(loc, e1, indexOne); + Value e3 = b.create( + loc, e2, outShapeIndexVector[i - 2]); + Value e4 = b.create(loc, indexOne, e3); + ends.push_back(e4); + } + SmallVector inputElementIndices; + inputElementIndices.push_back(ind[0]); + inputElementIndices.push_back(ind[1]); + for (unsigned i = 2; i < rank; i++) { + inputElementIndices.push_back( + b.create(loc, starts[i - 2], ind[rank - 2 + i])); + } + Value inElt = b.create(loc, elementType, buffInput, + inputElementIndices); + // check if we extracted at windex < end index + for (unsigned i = 0; i < rank - 2; i++) { + Value cond = + b.create(loc, arith::CmpIPredicate(6), + inputElementIndices[i + 2], ends[i]); + inElt = b.create(loc, cond, inElt, buffVal); + } + Value cond1 = b.create(loc, arith::CmpFPredicate::OGT, + inElt, res); + // index location is (ih * input_width + iw) + Value indexOut0 = b.create(loc, inputElementIndices[2], + inputSpatialSizes[1]); + Value indexOut1 = + b.create(loc, indexOut0, inputElementIndices[3]); + Value indexOut1Int = castIndexToInt64(b, loc, indexOut1); + Value indexOut2 = + b.create(loc, cond1, indexOut1Int, maxIndex); + Value out2 = b.create(loc, cond1, inElt, res); + b.create(loc, ValueRange({out2, indexOut2})); + }); + + Value maxValues = rewriter.create( + loc, outputType, maxPool.getResultTensors()[0]); + Value outputIndices = rewriter.create( + loc, outputType1, maxPool.getResultTensors()[1]); + rewriter.replaceOp(op, {maxValues, outputIndices}); + return success(); + } +}; +} // namespace void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); - patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); - patterns.add>( - typeConverter, context); - patterns.add>( - typeConverter, context); + patterns + .add>( + typeConverter, context); + patterns + .add>( + typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Conversion/TorchToLinalg/Random.cpp b/lib/Conversion/TorchToLinalg/Random.cpp index 26a2c0ea551a..35c349a6a673 100644 --- a/lib/Conversion/TorchToLinalg/Random.cpp +++ b/lib/Conversion/TorchToLinalg/Random.cpp @@ -194,7 +194,6 @@ class ConvertAtenUniformOp : public OpConversionPattern { }; } // namespace - void mlir::torch::torch_to_linalg::populateRandomPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index 289851cd3d27..952610c5404d 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -60,18 +60,15 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { Location loc = op.getLoc(); Value input = adaptor.getSelf(); - RankedTensorType valResultType = - getTypeConverter() - ->convertType(op.getResult(0).getType()) - .template cast(); - - RankedTensorType idxResultType = - this->getTypeConverter() - ->convertType(op.getResult(1).getType()) - .template cast(); + auto typec = this->getTypeConverter(); + auto valResultType = + cast(typec->convertType(op.getResult(0).getType())); + auto idxResultType = + cast(typec->convertType(op.getResult(1).getType())); RankedTensorType inputType = input.getType().template cast(); - Type idxElementType = idxResultType.getElementType(); + Type idxElementType = + getElementTypeOrSelf(typec->convertType(idxResultType)); if (!idxElementType.isa()) return rewriter.notifyMatchFailure( op, opName + " to linalg.* requires integer-like result type"); @@ -90,6 +87,7 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); Type inElementType = inputType.getElementType(); + bool isUnsigned = false; if (!inElementType.isa()) { if (inElementType.isa()) { auto integerTy = op.getSelf() @@ -97,26 +95,21 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { .template cast() .getDtype() .template dyn_cast(); - if (integerTy.isUnsigned()) - return rewriter.notifyMatchFailure( - op, opName + " to linalg.* requires input element type " - "to be signed in case of integer"); + isUnsigned = integerTy.isUnsigned(); } else { return rewriter.notifyMatchFailure( op, opName + " to linalg.* requires Float or Integer " - "input element type"); + "input element type"); } } // Constant op to account for the reduction along dim. - auto c1 = rewriter.create(loc, /*value=*/1); SmallVector resultShape; for (int64_t i = 0; i < inputType.getRank(); i++) { if (dim != i) { auto currentDimSize = rewriter.create(loc, input, i); resultShape.push_back(currentDimSize); - } else if (keepDim) - resultShape.push_back(c1); + } } // First fill the output buffer for the index. Value filledTensorIdx = @@ -135,40 +128,41 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { APFloat::getInf( inElementType.cast().getFloatSemantics(), /*Negative=*/isMax))); - } else { + } else if (!isUnsigned) { auto width = inElementType.cast().getWidth(); auto init = isMax ? APSInt::getSignedMinValue(width) : APSInt::getSignedMaxValue(width); fillValue = rewriter.create( loc, rewriter.getIntegerAttr(inElementType, init)); + } else if (isUnsigned) { + auto width = inElementType.cast().getWidth(); + auto init = isMax ? APInt::getMinValue(width) : APInt::getMaxValue(width); + fillValue = rewriter.create( + loc, rewriter.getIntegerAttr(inElementType, init)); } Value filledTensorVal = - rewriter.create(loc, fillValue, initTensorVal) - .result(); + rewriter.create(loc, fillValue, initTensorVal).result(); + + SmallVector iteratorTypes( + inputType.getRank(), utils::IteratorType::parallel); + iteratorTypes[dim] = utils::IteratorType::reduction; // Create the affine expressions that will be used to // iterate over the input and output tensors. // Here we also set the type of iterator: parallel or reduction. + SmallVector exprs; - SmallVector iteratorTypes; SmallVector resultExprs; for (auto size : llvm::enumerate(makeShapeTorchCompatible(inputType.getShape()))) { exprs.push_back(rewriter.getAffineDimExpr(size.index())); - - if (unsigned(dim) == size.index()) { - iteratorTypes.push_back(utils::IteratorType::reduction); - // If `keepDim`, create affine map to the first element - // in the current dimension. - if (keepDim) - resultExprs.push_back(rewriter.getAffineConstantExpr(0)); - } else { - iteratorTypes.push_back(utils::IteratorType::parallel); + if (unsigned(dim) != size.index()) resultExprs.push_back(rewriter.getAffineDimExpr(size.index())); - } } - auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs}); + + auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs}, + rewriter.getContext()); auto linalgOp = rewriter.create( loc, ArrayRef({filledTensorVal.getType(), filledTensorIdx.getType()}), @@ -186,7 +180,7 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { Value resultVal, predicate; if (inElementType.isa()) { - arith::CmpFPredicate predType; + arith::CmpFPredicate predType; if (isMax) { predType = arith::CmpFPredicate::OGT; resultVal = rewriter.create( @@ -198,17 +192,29 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { } predicate = rewriter.create(nestedLoc, predType, - newValue, oldValue); + newValue, oldValue); } else { arith::CmpIPredicate predType; if (isMax) { - predType = arith::CmpIPredicate::sgt; - resultVal = rewriter.create(nestedLoc, newValue, - oldValue); + predType = isUnsigned ? arith::CmpIPredicate::ugt + : arith::CmpIPredicate::sgt; + if (isUnsigned) { + resultVal = rewriter.create(nestedLoc, newValue, + oldValue); + } else { + resultVal = rewriter.create(nestedLoc, newValue, + oldValue); + } } else { - predType = arith::CmpIPredicate::slt; - resultVal = rewriter.create(nestedLoc, newValue, - oldValue); + predType = isUnsigned ? arith::CmpIPredicate::ult + : arith::CmpIPredicate::slt; + if (isUnsigned) { + resultVal = rewriter.create(nestedLoc, newValue, + oldValue); + } else { + resultVal = rewriter.create(nestedLoc, newValue, + oldValue); + } } predicate = rewriter.create(nestedLoc, predType, newValue, oldValue); @@ -219,12 +225,58 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { nestedLoc, ValueRange({resultVal, resultIndex})); }); - // This cast is required to fix the shape in the case of keepDim=True - Value valuesCast = rewriter.create( - loc, valResultType, linalgOp.getResult(0)); - Value idxCast = rewriter.create(loc, idxResultType, - linalgOp.getResult(1)); - rewriter.replaceOp(op, {valuesCast, idxCast}); + if (!keepDim) { + Value rVal = rewriter.create(loc, valResultType, + linalgOp.getResult(0)); + Value rIdx = rewriter.create(loc, idxResultType, + linalgOp.getResult(1)); + llvm::SmallVector res{rVal, rIdx}; + rewriter.replaceOp(op, res); + return success(); + } + + llvm::SmallVector valShape(valResultType.getShape()); + llvm::SmallVector idxShape(idxResultType.getShape()); + for (int i = dim, s = valShape.size() - 1; i < s; ++i) { + valShape[i] = valShape[i + 1]; + idxShape[i] = idxShape[i + 1]; + } + + valShape.resize(valShape.size() - 1); + idxShape.resize(idxShape.size() - 1); + + Value rVal = rewriter.create( + loc, valResultType.clone(valShape), linalgOp.getResult(0)); + Value rIdx = rewriter.create( + loc, idxResultType.clone(idxShape), linalgOp.getResult(1)); + + SmallVector reassociation(valShape.size()); + if (reassociation.size() > 0) { + for (int i = 0; i < dim; ++i) + reassociation[i].push_back(i); + reassociation[std::max(0, dim - 1)].push_back(dim); + for (int i = dim, s = reassociation.size(); i < s; ++i) + reassociation[i].push_back(i + 1); + } + + valShape.push_back(0); + idxShape.push_back(0); + for (int i = dim, s = valShape.size() - 1; i < s; ++i) { + valShape[i + 1] = valShape[i]; + idxShape[i + 1] = idxShape[i]; + } + + valShape[dim] = 1; + idxShape[dim] = 1; + + Value unsqueezeVal = rewriter.create( + loc, valResultType, rVal, reassociation); + + Value unsqueezeIdx = rewriter.create( + loc, idxResultType, rIdx, reassociation); + + llvm::SmallVector unsqueezes = {unsqueezeVal, unsqueezeIdx}; + rewriter.replaceOp(op, unsqueezes); return success(); } }; @@ -275,9 +327,14 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc, elementType.getIntOrFloatBitWidth()))); } - if (isa(op) || isa(op)) + if (isa(op) || isa(op) || + isa(op)) return b.create(loc, b.getZeroAttr(elementType)); + if (isa(op)) { + return b.create(loc, b.getBoolAttr(true)); + } + op->emitError("unimplemented lowering in createInitElementForReduceOp"); return nullptr; } @@ -337,6 +394,26 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, if (intType.isSigned()) return b.create(loc, self, result); } + } else if (isa(op)) { + // This creates payload for only the first of the two linalg.generic ops. + // TODO: Short-circuit operations if `p` is zero or one. + Value elem = payloadArgs[0]; + Value result = payloadArgs[1]; + + // TODO: Fix this part to support complex elements. + if (elem.getType().isa()) { + op->emitError("lowering of complex input type for torch.aten.norm.Scalar " + "is currently unimplemented"); + return nullptr; + } + + Value self = convertScalarToDtype(b, loc, elem, resultElementType); + + auto abs = b.create(loc, self); + AtenNormScalarOp::Adaptor adaptor(operands); + Value p = convertScalarToDtype(b, loc, adaptor.getP(), resultElementType); + auto pow = b.create(loc, abs, p); + return b.create(loc, pow, result); } else if (isa(op)) { // This creates payload for only the first of the two linalg.generic ops. // TODO: Short-circuit operations if `ord` is zero or one. @@ -345,7 +422,8 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, Value self = convertScalarToDtype(b, loc, elem, resultElementType); auto abs = b.create(loc, self); AtenLinalgVectorNormOp::Adaptor adaptor(operands); - Value ord = convertScalarToDtype(b, loc, adaptor.getOrd(), resultElementType); + Value ord = + convertScalarToDtype(b, loc, adaptor.getOrd(), resultElementType); auto pow = b.create(loc, abs, ord); return b.create(loc, pow, result); } else if (isa(op)) { @@ -357,6 +435,11 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, auto ord = b.create(loc, twoAttr); auto pow = b.create(loc, abs, ord); return b.create(loc, pow, result); + } else if (isa(op)) { + Value elem = payloadArgs[0]; + Value result = payloadArgs[1]; + Value self = convertScalarToDtype(b, loc, elem, resultElementType); + return b.create(loc, self, result); } op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp"); return nullptr; @@ -423,12 +506,12 @@ class ConvertReductionOp : public ConversionPattern { ConversionPatternRewriter &rewriter) const { auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}}; - if (isa(op)) { + if (isa(op)) { opInfo.tensorOperand = operands[0]; auto inputType = opInfo.tensorOperand.getType().cast(); - // `AtenSumOp`, `AtenMaxOp`, and `AtenMinOp` each reduce along all the dimensions of the - // input tensor. + // `AtenSumOp`, `AtenMaxOp`, and `AtenMinOp` each reduce along all the + // dimensions of the input tensor. for (int64_t i = 0; i < inputType.getRank(); i++) opInfo.dimSet.insert(i); @@ -447,6 +530,9 @@ class ConvertReductionOp : public ConversionPattern { if (auto normOp = dyn_cast(op)) return computeReductionOpInfoForDimVariantOp(normOp, operands, rewriter); + if (auto allOp = dyn_cast(op)) + return computeReductionOpInfoForDimVariantOp(allOp, operands, rewriter); + return rewriter.notifyMatchFailure(op, "not a supported reduce op"); } @@ -471,10 +557,12 @@ class ConvertReductionOp : public ConversionPattern { return err ? Value{} : powOp; } - FailureOr createSecondReductionForVectorNormOp( - Location loc, Type elemType, AtenLinalgVectorNormOp op, Value ordOp, - Value firstReduction, const torch_to_linalg::ReductionOpInfo &opInfo, - ConversionPatternRewriter &rewriter) const { + template + FailureOr + createSecondReductionForNormOp(Location loc, Type elemType, TOp op, + Value ordOp, Value firstReduction, + const torch_to_linalg::ReductionOpInfo &opInfo, + ConversionPatternRewriter &rewriter) const { // Cast `ord` to float so that we can readily pass it math.powf. Value ordValue = convertScalarToDtype(rewriter, loc, ordOp, elemType); @@ -531,10 +619,15 @@ class ConvertReductionOp : public ConversionPattern { LogicalResult validateReductionElementType(Operation *op, Type elemType, ConversionPatternRewriter &rewriter) const { - if ((isa(op) || isa(op)) && + if ((isa(op) || isa(op) || + isa(op)) && !elemType.isa()) return rewriter.notifyMatchFailure( op, "only float types are valid for vector norm ops"); + if (isa(op) && elemType.isa() && + elemType.getIntOrFloatBitWidth() == 8) + return rewriter.notifyMatchFailure(op, "uint8 is not supported"); + // No checks for all other reduction operations return success(); } @@ -571,11 +664,22 @@ class ConvertReductionOp : public ConversionPattern { return rewriter.notifyMatchFailure( op, "failed to create linalg.generic operation for reduction"); + // If this is aten.norm.Scalar op, then we need to generate another + // linalg.generic op that references the first linalg.generic op. + if (isa(op)) { + AtenNormScalarOp::Adaptor adaptor(operands); + FailureOr secondReduceOp = createSecondReductionForNormOp( + loc, elemType, op, adaptor.getP(), reduceOp, *opInfo, rewriter); + if (failed(secondReduceOp)) + return secondReduceOp; + reduceOp = *secondReduceOp; + } + // If this is aten.linalg_vector_norm op, then we need to generate another // linalg.generic op that references the first linalg.generic op. if (auto normOp = dyn_cast(op)) { AtenLinalgVectorNormOp::Adaptor adaptor(operands); - FailureOr secondReduceOp = createSecondReductionForVectorNormOp( + FailureOr secondReduceOp = createSecondReductionForNormOp( loc, elemType, normOp, adaptor.getOrd(), reduceOp, *opInfo, rewriter); if (failed(secondReduceOp)) return secondReduceOp; @@ -610,6 +714,8 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality( target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp index 434b50b034dd..385f5b435e1b 100644 --- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp +++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp @@ -24,6 +24,8 @@ #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include + using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; @@ -43,42 +45,296 @@ class ConvertAtenConstantPadNdOp auto type = self.getType().cast(); int64_t rank = type.getRank(); - // Pattern match against the op's original operands, because otherwise we - // will get the lowered version of the operands which is harder to pattern - // match. - SmallVector padInts; - if (!matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts))) - return rewriter.notifyMatchFailure( - op, "only support constant int pad ranges"); - uint64_t padRank = padInts.size() / 2; - if (padRank * 2 != padInts.size()) + auto primList = op.getPad().getDefiningOp(); + if (!primList) { + return rewriter.notifyMatchFailure(op, "unable to get pad values"); + } + + SmallVector padVals(primList.getOperands()); + + uint64_t padRank = padVals.size() / 2; + if (padRank * 2 != padVals.size()) return rewriter.notifyMatchFailure(op, "pad range size is not even"); if (rank < 0 || padRank > (uint64_t)rank) return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank"); // Initialize low/high paddings with the dims that should not be padded. - SmallVector lowPadding(/*Size=*/rank - padRank, /*Value=*/0); - SmallVector highPadding(/*Size=*/rank - padRank, /*Value=*/0); + int64_t noPad = rank - padRank; + Attribute zero = rewriter.getIndexAttr(0); + SmallVector staticLow(noPad, 0); + SmallVector staticHigh(noPad, 0); + SmallVector lowPad(noPad, zero); + SmallVector highPad(noPad, zero); + + auto tc = getTypeConverter(); + // Add the requested padding - note op.pad() is highest dim first ordered // pairs of low,high. for (uint64_t i = padRank; i > 0; --i) { - lowPadding.push_back(padInts[i * 2 - 2]); - highPadding.push_back(padInts[i * 2 - 1]); + int64_t lowi, highi; + Value lowv = padVals[i * 2 - 2]; + Value highv = padVals[i * 2 - 1]; + if (!matchPattern(lowv, m_TorchConstantInt(&lowi))) { + Type cty = tc->convertType(lowv.getType()); + lowv = tc->materializeTargetConversion(rewriter, loc, cty, lowv); + lowv = rewriter.create(loc, rewriter.getIndexType(), + lowv); + lowPad.push_back(lowv); + staticLow.push_back(ShapedType::kDynamic); + } else { + lowPad.push_back(rewriter.getIndexAttr(lowi)); + staticLow.push_back(lowi); + } + + if (!matchPattern(highv, m_TorchConstantInt(&highi))) { + Type cty = tc->convertType(highv.getType()); + highv = tc->materializeTargetConversion(rewriter, loc, cty, highv); + highv = rewriter.create( + loc, rewriter.getIndexType(), highv); + highPad.push_back(highv); + staticHigh.push_back(ShapedType::kDynamic); + } else { + highPad.push_back(rewriter.getIndexAttr(highi)); + staticHigh.push_back(highi); + } } Type newResultType = getTypeConverter()->convertType(op.getType()); Type elementType = newResultType.cast().getElementType(); Value castedValue = convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType); - Value paddedInput = torch_to_linalg::getPaddedTensor( - op, rewriter, self, lowPadding, highPadding, castedValue); + Type padType = tensor::PadOp::inferResultType( + self.getType().cast(), staticLow, staticHigh); + Value paddedInput = rewriter.create( + loc, padType, self, lowPad, highPad, castedValue); rewriter.replaceOpWithNewOp(op, newResultType, paddedInput); return success(); } }; } // namespace +namespace { + +// Lower aten.replication_pad2d operator into a sequence of +// tensor.extract_slice and tensor.concat operations. + +class ConvertAtenReplicationPad2dOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenReplicationPad2dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + Location loc = op->getLoc(); + Value input = adaptor.getSelf(); + auto inputType = llvm::cast(input.getType()); + int64_t inputRank = inputType.getRank(); + unsigned numDims = inputType.getRank(); + assert(numDims >= 2 && "Not enough input dimensions"); + + SmallVector padInts; + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) + return rewriter.notifyMatchFailure( + op, "only support constant int pad ranges"); + uint64_t padRank = padInts.size() / 2; + if (padRank * 2 != padInts.size()) + return rewriter.notifyMatchFailure(op, "pad range size is not even"); + if (inputRank < 0 || padRank > (uint64_t)inputRank) + return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank"); + + SmallVector inputShape = getTensorSizes(rewriter, loc, input); + int64_t hDim = numDims - 1; + int64_t vDim = numDims - 2; + Value hDimSize = inputShape[hDim]; + Value vDimSize = inputShape[vDim]; + + enum tileHLoc { LEFT = 0, HCENTER = 1, RIGHT = 2 }; + enum tileVLoc { + TOP = 0, + VCENTER = 2, + BOTTOM = 1, + }; + // vTile denotes the vertical size of the tile + // hTile denotes the horizontal size of the tile + // The padding results are composed of following tiles: + // vTile[TOP]hTile[LEFT], vTile[TOP]hTile[HCENTER], vTile[TOP]hTile[RIGHT] + // vTile[VCENTER]hTile[LEFT], vTile[VCENTER]hTile[HCENTER], + // vTile[VCENTER]hTile[RIGHT] vTile[BOTTOM]hTile[LEFT], + // vTile[BOTTOM]hTile[HCENTER], vTile[BOTTOM]hTile[RIGHT] + // vTile[VCENTER]hTile[HCENTER] is the original input tensor + Type indexType = rewriter.getIndexType(); + Value vTile[3]; + Value hTile[3]; + vTile[VCENTER] = vDimSize; + hTile[HCENTER] = hDimSize; + vTile[TOP] = getConstant(rewriter, loc, padInts[2], indexType); + vTile[BOTTOM] = getConstant(rewriter, loc, padInts[3], indexType); + hTile[LEFT] = getConstant(rewriter, loc, padInts[0], indexType); + hTile[RIGHT] = getConstant(rewriter, loc, padInts[1], indexType); + + bool hasLeftPadding = false; + bool hasRightPadding = false; + bool hasTopPadding = false; + bool hasBottomPadding = false; + + for (auto i : {TOP, VCENTER, BOTTOM}) { + for (auto j : {LEFT, HCENTER, RIGHT}) { + auto constVtile{ + mlir::dyn_cast(vTile[i].getDefiningOp()) + .getValue() + .dyn_cast_or_null()}; + + auto constHtile{ + mlir::dyn_cast(hTile[j].getDefiningOp()) + .getValue() + .dyn_cast_or_null()}; + auto vSize = constVtile.getInt(); + auto hSize = constHtile.getInt(); + + if ((i == TOP) && (vSize > 0)) + hasTopPadding = true; + if ((i == BOTTOM) && (vSize > 0)) + hasBottomPadding = true; + if ((j == LEFT) && (hSize > 0)) + hasLeftPadding = true; + if ((j == RIGHT) && (hSize > 0)) + hasRightPadding = true; + } + } + + auto createSub = [&](Value x, Value y) { + return rewriter.create(loc, x, y); + }; + + // Extract left and right pad tiles. + Value zero = getConstant(rewriter, loc, 0, indexType); + Value one = getConstant(rewriter, loc, 1, indexType); + Value hDimSizeMinusOne = createSub(hDimSize, one); + Value vDimSizeMinusOne = createSub(vDimSize, one); + SmallVector allOneStrides(numDims, one); + + SmallVector extractOffsetsLT(numDims, zero); + extractOffsetsLT[hDim] = zero; + extractOffsetsLT[vDim] = zero; + SmallVector extractShapeLR(numDims, one); + extractShapeLR[hDim] = one; + extractShapeLR[vDim] = vDimSize; + + SmallVector extractOffsetsRight(numDims, zero); + extractOffsetsRight[hDim] = hDimSizeMinusOne; + extractOffsetsRight[vDim] = zero; + + SmallVector extractOffsetsBottom(numDims, zero); + extractOffsetsBottom[hDim] = zero; + extractOffsetsBottom[vDim] = vDimSizeMinusOne; + + SmallVector extractShapeTB(numDims, one); + extractShapeTB[hDim] = hDimSize; + extractShapeTB[vDim] = one; + + SmallVector tensorsLeft; + SmallVector tensorsRight; + SmallVector tensorsCenter; + Value centerTile; + SmallVector tensorsRes; + + if (hasLeftPadding) { + Value vCenterLeftSlice = rewriter.create( + loc, input, extractOffsetsLT, extractShapeLR, allOneStrides); + Value vLeftSlice = vCenterLeftSlice; + if (hasTopPadding) { + Value topLeftValue = rewriter.create( + loc, input, ValueRange{zero, zero, zero, zero}); + // pad vCenterLeftSlice on the top + SmallVector lowPadding(4, 0); + SmallVector highPadding(4, 0); + lowPadding[2] = padInts[2]; + vLeftSlice = torch_to_linalg::getPaddedTensor( + op, rewriter, vLeftSlice, lowPadding, highPadding, topLeftValue); + } + if (hasBottomPadding) { + Value bottomLeftValue = rewriter.create( + loc, input, ValueRange{zero, zero, vDimSizeMinusOne, zero}); + + // pad vLeftSlice at the bottom + SmallVector lowPadding(4, 0); + SmallVector highPadding(4, 0); + highPadding[2] = padInts[3]; + vLeftSlice = torch_to_linalg::getPaddedTensor( + op, rewriter, vLeftSlice, lowPadding, highPadding, bottomLeftValue); + } + for (auto i = 0; i < padInts[0]; ++i) { + tensorsLeft.push_back(vLeftSlice); + } + Value leftPadTile = + rewriter.create(loc, 3, tensorsLeft); + tensorsRes.push_back(leftPadTile); + } + if (hasTopPadding) { + Value topHcenterSlice = rewriter.create( + loc, input, extractOffsetsLT, extractShapeTB, allOneStrides); + for (auto i = 0; i < padInts[2]; ++i) { + tensorsCenter.push_back(topHcenterSlice); + } + } + tensorsCenter.push_back(input); + if (hasBottomPadding) { + Value bottomHcenterSlice = rewriter.create( + loc, input, extractOffsetsBottom, extractShapeTB, allOneStrides); + for (auto i = 0; i < padInts[3]; ++i) { + tensorsCenter.push_back(bottomHcenterSlice); + } + } + centerTile = rewriter.create(loc, 2, tensorsCenter); + tensorsRes.push_back(centerTile); + + if (hasRightPadding) { + Value vCenterRightSlice = rewriter.create( + loc, input, extractOffsetsRight, extractShapeLR, allOneStrides); + Value vRightSlice = vCenterRightSlice; + if (hasTopPadding) { + Value topRightValue = rewriter.create( + loc, input, ValueRange{zero, zero, zero, hDimSizeMinusOne}); + + // pad vCenterRightSlice on the top + SmallVector lowPadding(4, 0); + SmallVector highPadding(4, 0); + lowPadding[2] = padInts[2]; + vRightSlice = torch_to_linalg::getPaddedTensor( + op, rewriter, vRightSlice, lowPadding, highPadding, topRightValue); + } + if (hasBottomPadding) { + Value bottomRightValue = rewriter.create( + loc, input, + ValueRange{zero, zero, vDimSizeMinusOne, hDimSizeMinusOne}); + + // Pad vCenterRightSlice or vRightTopPaddedSlice at the bottom. + SmallVector lowPadding(4, 0); + SmallVector highPadding(4, 0); + highPadding[2] = padInts[3]; + vRightSlice = torch_to_linalg::getPaddedTensor( + op, rewriter, vRightSlice, lowPadding, highPadding, + bottomRightValue); + } + for (auto i = 0; i < padInts[1]; ++i) { + tensorsRight.push_back(vRightSlice); + } + Value rightPadTile = + rewriter.create(loc, 3, tensorsRight); + tensorsRes.push_back(rightPadTile); + } + Value resTensor = rewriter.create(loc, 3, tensorsRes); + Type newResultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, resTensor); + return success(); + } +}; +} // namespace + namespace { // Converts constant tensor allocation like ops. template @@ -140,8 +396,8 @@ class ConvertConstantTensorAllocOp : public OpConversionPattern { // Create an uninitialized tensor of `resultSize` shape and fill it with // value `fillVal`. Value constVal = getConstant(rewriter, loc, fillVal, resultElementType); - Value outputTensor = - createInitTensor(rewriter, loc, resultSizeIndex, resultElementType, constVal); + Value outputTensor = createInitTensor(rewriter, loc, resultSizeIndex, + resultElementType, constVal); rewriter.replaceOpWithNewOp(op, resultType, outputTensor); return success(); } @@ -176,7 +432,8 @@ class ConvertAtenEmptyMemoryFormatOp // Only `none`, `contiguous` and `preserve` memory_format is supported. if (!op.getMemoryFormat().getType().isa()) { int64_t memoryFormat; - if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat))) + if (!matchPattern(op.getMemoryFormat(), + m_TorchConstantInt(&memoryFormat))) return rewriter.notifyMatchFailure( op, "unimplemented: the memory format should be specified in " "an integer constant"); @@ -287,7 +544,8 @@ class ConvertAtenArangeStartStepOp typeConverter->convertType(op->getResult(0).getType()) .cast(); Type dtype = resultType.getElementType(); - Value start = convertScalarToDtype(rewriter, loc, adaptor.getStart(), dtype); + Value start = + convertScalarToDtype(rewriter, loc, adaptor.getStart(), dtype); Value end = convertScalarToDtype(rewriter, loc, adaptor.getEnd(), dtype); Value step = convertScalarToDtype(rewriter, loc, adaptor.getStep(), dtype); @@ -348,6 +606,8 @@ void mlir::torch::torch_to_linalg:: RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); diff --git a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp index a1e8e5fb72d9..58e6daa9bca8 100644 --- a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp +++ b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp @@ -191,12 +191,14 @@ class ConvertPrimNumToTensorScalarOp } // namespace namespace { -class ConvertAtenScalarImplicitOp - : public OpConversionPattern { +// Converts a tensor with one element to a scalar value. +template +class ConvertAtenImplicitLikeOp : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(AtenScalarImplicitOp op, OpAdaptor adaptor, + matchAndRewrite(OpTy op, + typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, adaptor.getA()); return success(); @@ -224,6 +226,12 @@ void mlir::torch::torch_to_linalg:: patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - target.addIllegalOp(); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + target.addIllegalOp(); } diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 2cc37a88313a..86bc4578178f 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -128,16 +128,20 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) { } template -static Value -createCalculationForMathOpWithDtypeConversion(OpBuilder &b, - const TypeConverter *converter, - Value payloadArg, Operation *op) { - Type dtype = converter->convertType(op->getResult(0).getType()) - .template cast() - .getElementType(); +static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter, + Value payloadArg, Operation *op) { + Type inTTy = cast(op->getOperand(0).getType()).getDtype(); + Type outTTy = cast(op->getResult(0).getType()).getDtype(); + Type outTy = + cast(converter->convertType(op->getResult(0).getType())) + .getElementType(); + Type computeTy = outTy; + if (isa(computeTy)) + computeTy = b.getF32Type(); Location loc = op->getLoc(); - Value arg = convertScalarToDtype(b, loc, payloadArg, dtype); - return b.create(loc, arg); + Value arg = convertScalarToDtype(b, loc, payloadArg, computeTy, inTTy); + auto newOp = b.create(loc, arg); + return convertScalarToDtype(b, loc, newOp, outTy, std::nullopt, outTTy); } template @@ -216,65 +220,71 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, payloadArgs[0]); if (isa(op)) return b.create(loc, payloadArgs[0]); - if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); - } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); + } + if (isa(op)) { + return createFpOpWithDtype(b, converter, payloadArgs[0], op); + } + if (isa(op)) { + return createFpOpWithDtype(b, converter, payloadArgs[0], op); + } + if (isa(op)) { + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } - if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + if (isa(op)) { + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); + } + if (isa(op)) { + return createFpOpWithDtype(b, converter, payloadArgs[0], op); + } + if (isa(op)) { + return createFpOpWithDtype(b, converter, payloadArgs[0], op); + } + if (isa(op)) { + return createFpOpWithDtype(b, converter, payloadArgs[0], op); + } + if (isa(op)) { + return createFpOpWithDtype(b, converter, payloadArgs[0], op); + } + if (isa(op)) { + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (auto clone = dyn_cast(op)) { int64_t memoryFormat; @@ -412,16 +422,37 @@ static Value createLinalgPayloadCalculationForElementwiseOp( b.create(loc, b.getFloatAttr(floatDtype, 0)); return createEqual(b, loc, floatDtype, self, zero); } - if (isa(op)) + if (isa(op)) { + if (payloadArgs[0].getType().isa()) + return b.create(loc, payloadArgs[0]); return b.create(loc, payloadArgs[0]); + } + if (isa(op)) { + Value abs = b.create(loc, payloadArgs[0]); + Value infinity = b.create( + loc, + b.getFloatAttr(abs.getType(), std::numeric_limits::infinity())); + return createEqual(b, loc, abs.getType(), abs, infinity); + } if (isa(op)) { - auto negate = createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + Type inTTy = cast(op->getOperand(0).getType()).getDtype(); + Type outTTy = cast(op->getResult(0).getType()).getDtype(); + Type outTy = cast( + converter->convertType(op->getResult(0).getType())) + .getElementType(); + Type computeTy = outTy; + if (isa(computeTy)) + computeTy = b.getF32Type(); + + Value arg = payloadArgs[0]; + arg = convertScalarToDtype(b, loc, payloadArgs[0], computeTy, inTTy); + auto negate = b.create(loc, arg); auto one = b.create(loc, FloatAttr::get(negate.getType(), 1)); auto exp = b.create(loc, negate); auto added = b.create(loc, exp, one); - return b.create(loc, one, added); + auto div = b.create(loc, one, added); + return convertScalarToDtype(b, loc, div, outTy, std::nullopt, outTTy); } if (auto relu = dyn_cast(op)) { if (!relu.getType() @@ -480,11 +511,42 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } // TODO: Take approximation into account. std::string approximate; - if (!matchPattern(gelu.getApproximate(), m_TorchConstantStr(approximate)) || - approximate != "none") + if (!matchPattern(gelu.getApproximate(), m_TorchConstantStr(approximate))) { + gelu.emitError( + "unimplemented: expected approximate to be a constant str"); return nullptr; - Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[0]); - return b.create(loc, payloadArgs[0], cdf); + } + if (approximate == "none") { + Value multiplier = buildUnitNormalCdf(b, loc, payloadArgs[0]); + return b.create(loc, payloadArgs[0], multiplier); + } + if (approximate == "tanh") { + // GELU(x)=0.5∗x∗(1+Tanh((2/Ï€)^1/2 * (x+0.044715∗x^3))) + // Ref: https://pytorch.org/docs/stable/generated/torch.nn.GELU.html + Value cstThree = b.create( + loc, IntegerAttr::get(IntegerType::get(op->getContext(), 64), 3)); + Value xCube = b.create(loc, payloadArgs[0], cstThree); + Type elementType = payloadArgs[0].getType(); + Value cstAlpha = b.create( + loc, FloatAttr::get(elementType, 0.044715)); + Value xCubeMulAlpha = b.create(loc, xCube, cstAlpha); + Value xPlusXCubeMulAlpha = + b.create(loc, payloadArgs[0], xCubeMulAlpha); + Value cstBeta = b.create( + loc, FloatAttr::get(elementType, 0.7977240352174656)); + Value betaMulX = + b.create(loc, cstBeta, xPlusXCubeMulAlpha); + Value tanh = b.create(loc, betaMulX); + Value cstOne = + b.create(loc, FloatAttr::get(elementType, 1.0)); + Value onePlusTanh = b.create(loc, cstOne, tanh); + Value cstHalf = + b.create(loc, FloatAttr::get(elementType, 0.5)); + Value multiplier = b.create(loc, cstHalf, onePlusTanh); + return b.create(loc, payloadArgs[0], multiplier); + } + gelu.emitError("unimplemented: approximate value should be none or tanh"); + return nullptr; } if (auto geluBackward = dyn_cast(op)) { if (!geluBackward.getType() @@ -548,12 +610,19 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto add = dyn_cast(op)) { AtenAddTensorOp::Adaptor adaptor(operands); + Type resultElementType = add.getType().cast().getDtype(); Type dtype = converter->convertType(add.getType()) .cast() .getElementType(); - Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); - Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); - Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype); + Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); + Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); + Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); if (dtype.isa()) { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); @@ -576,7 +645,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( /*dstOriginalDtype=*/resultElementType); Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype, /*srcOriginalDtype=*/std::nullopt, - /*dstOriginalDtype=*/resultElementType); + /*dstOriginalDtype=*/resultElementType, + /*originalScalar=*/sub.getAlpha()); if (dtype.isa()) { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); @@ -591,7 +661,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value other = convertScalarToDtype(b, loc, operands[1], dtype); - Value alpha = convertScalarToDtype(b, loc, operands[2], dtype); + Value alpha = convertScalarToDtype( + b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(), + /*dstOriginalDtype=*/dtype); if (dtype.isa()) { Value mult = b.create(loc, other, alpha); return b.create(loc, self, mult); @@ -685,13 +757,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = converter->convertType(div.getType()) .cast() .getElementType(); - if (!dtype.isa()) { - div.emitError("unimplemented: non-floating point dtype"); - return nullptr; - } Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); - return b.create(loc, lhs, rhs); + if (dtype.isa()) + return b.create(loc, lhs, rhs); + else if (dtype.isa()) { + if (dtype.isUnsignedInteger()) + return b.create(loc, lhs, rhs); + return b.create(loc, lhs, rhs); + } + div.emitError("unimplemented: non-floating point and non-integer dtype"); + return nullptr; } if (auto divTensorMode = dyn_cast(op)) { AtenDivTensorModeOp::Adaptor adaptor(operands); @@ -989,13 +1065,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, pred, lhs, rhs); } if (auto clamp = dyn_cast(op)) { - Type dtype = converter->convertType(clamp.getType()) - .cast() - .getElementType(); - if (!dtype.isa()) { - clamp.emitError("unimplemented: non-floating point dtype"); - return nullptr; - } AtenClampOp::Adaptor adaptor(operands); auto min = adaptor.getMin(); auto max = adaptor.getMax(); @@ -1004,19 +1073,45 @@ static Value createLinalgPayloadCalculationForElementwiseOp( clamp.emitError("unimplemented: runtime optional type"); return nullptr; } - auto result = payloadArgs[0]; - if (!min.getType().isa()) { - auto minPromoted = convertScalarToDtype(b, loc, min, dtype); - auto pred = b.create(loc, arith::CmpFPredicate::ULT, - result, minPromoted); - result = b.create(loc, pred, minPromoted, result); + + Type dtype = converter->convertType(clamp.getType()) + .cast() + .getElementType(); + if (!dtype.isa()) { + clamp.emitError("unimplement type for clamp"); + return nullptr; } - if (!max.getType().isa()) { - auto maxPromoted = convertScalarToDtype(b, loc, max, dtype); - auto pred = b.create(loc, arith::CmpFPredicate::UGT, - result, maxPromoted); - result = b.create(loc, pred, maxPromoted, result); + + Type dstOriginalDtype = clamp.getType().cast().getDtype(); + bool isUnsigned = isa(dstOriginalDtype); + if (auto intTy = dstOriginalDtype.dyn_cast()) { + isUnsigned = intTy.isUnsigned(); } + auto cmpSelect = [&](Value input, Value clamp, bool getMax) -> Value { + clamp = convertScalarToDtype(b, loc, clamp, dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/dstOriginalDtype); + + Value pred; + if (dtype.isa()) { + auto cmp = + getMax ? arith::CmpFPredicate::UGT : arith::CmpFPredicate::ULT; + pred = b.create(loc, cmp, input, clamp); + } else if (dtype.isa()) { + auto cmp = + isUnsigned ? arith::CmpIPredicate::ult : arith::CmpIPredicate::slt; + if (getMax) + cmp = arith::invertPredicate(cmp); + pred = b.create(loc, cmp, input, clamp); + } + return b.create(loc, pred, clamp, input); + }; + + auto result = payloadArgs[0]; + if (!min.getType().isa()) + result = cmpSelect(result, min, /*getMax=*/false); + if (!max.getType().isa()) + result = cmpSelect(result, max, /*getMax=*/true); return result; } if (auto clampTensor = dyn_cast(op)) { @@ -1077,7 +1172,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value other = convertScalarToDtype(b, loc, operands[1], dtype); - Value alpha = convertScalarToDtype(b, loc, operands[2], dtype); + Value alpha = convertScalarToDtype( + b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(), + /*dstOriginalDtype=*/dtype); if (dtype.isa()) { Value mult = b.create(loc, self, alpha); return b.create(loc, other, mult); @@ -1158,6 +1255,49 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return result; } + if (auto remTensor = dyn_cast(op)) { + Type newResultType = converter->convertType(remTensor.getType()) + .cast() + .getElementType(); + + Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); + Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType); + Value result; + + if (newResultType.isa()) { + result = b.create(loc, self, other); + } else if (newResultType.isa()) { + result = b.create(loc, self, other); + } else { + remTensor.emitError( + "Unsupported type encountered for AtenRemainderTensorOp."); + } + + return result; + } + if (auto fmod = dyn_cast(op)) { + Type newResultType = converter->convertType(fmod.getType()) + .cast() + .getElementType(); + + Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); + Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType); + Value result; + + if (newResultType.isa()) { + Value n = b.create(loc, self, other); + n = b.create(loc, n); + Value n_y = b.create(loc, n, other); + result = b.create(loc, self, n_y); + } else if (newResultType.isa()) { + Value n = b.create(loc, self, other); + Value n_y = b.create(loc, n, other); + result = b.create(loc, self, n_y); + } else { + fmod.emitError("Unsupported type encountered for AtenFmodTensorOp."); + } + return result; + } if (auto reciprocal = dyn_cast(op)) { Type dtype = converter->convertType(reciprocal.getType()) .cast() @@ -1279,6 +1419,111 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, payloadArgs[0], allOnesVal); } + if (isa(op)) { + auto value = payloadArgs[0]; + auto valueTy = value.getType(); + auto qtensor = op->getOperand(0); + auto qtensorTy = qtensor.getType().cast().getDtype(); + + Value zp, scale; + if (auto makeQTensor = + qtensor.getDefiningOp()) { + zp = makeQTensor.getZeroPoint(); + scale = makeQTensor.getScale(); + } + + if (auto quant = qtensor.getDefiningOp()) { + zp = quant.getZeroPoint(); + scale = quant.getScale(); + } + + if (!zp || !scale) { + return nullptr; + } + + auto outFpTy = payloadArgs[1].getType(); + auto outBw = outFpTy.getIntOrFloatBitWidth(); + auto outIntTy = b.getIntegerType(outBw); + + if (valueTy != outIntTy) { + if (torch_to_linalg::isUnsignedTorchType(qtensorTy)) { + value = b.create(loc, outIntTy, value); + } else { + value = b.create(loc, outIntTy, value); + } + } + + zp = converter->materializeTargetConversion( + b, loc, converter->convertType(zp.getType()), zp); + auto zpTy = zp.getType(); + + if (zpTy != outIntTy) { + zp = b.create(loc, outIntTy, zp); + } + + value = b.create(loc, value, zp); + + if (torch_to_linalg::isUnsignedTorchType(qtensorTy)) { + value = b.create(loc, outFpTy, value); + } else { + value = b.create(loc, outFpTy, value); + } + + scale = converter->materializeTargetConversion( + b, loc, converter->convertType(scale.getType()), scale); + if (scale.getType() != value.getType()) { + scale = b.create(loc, value.getType(), scale); + } + value = b.create(loc, value, scale); + return value; + } + + if (auto quant = dyn_cast(op)) { + Value value = payloadArgs[0]; + Value scale = quant.getScale(); + Value zp = quant.getZeroPoint(); + auto valueTy = value.getType(); + + zp = converter->materializeTargetConversion( + b, loc, converter->convertType(zp.getType()), zp); + zp = b.create(loc, valueTy, zp); + + scale = converter->materializeTargetConversion( + b, loc, converter->convertType(scale.getType()), scale); + scale = b.create(loc, valueTy, scale); + + value = b.create(loc, value, scale); + value = b.create(loc, value); + value = b.create(loc, value, zp); + + auto destTy = payloadArgs[1].getType(); + auto bitwidth = destTy.getIntOrFloatBitWidth(); + bool isUnsigned = torch_to_linalg::isUnsignedTorchType(quant.getType()); + APInt min = isUnsigned ? APInt::getMinValue(bitwidth) + : APInt::getSignedMinValue(bitwidth); + APInt max = isUnsigned ? APInt::getMaxValue(bitwidth) + : APInt::getSignedMaxValue(bitwidth); + + Value minVal = b.create( + loc, b.getFloatAttr(valueTy, min.getSExtValue())); + Value maxVal = b.create( + loc, b.getFloatAttr(valueTy, max.getSExtValue())); + Value minCmp = + b.create(loc, arith::CmpFPredicate::ULT, value, minVal); + Value maxCmp = + b.create(loc, arith::CmpFPredicate::UGT, value, maxVal); + value = b.create(loc, minCmp, minVal, value); + value = b.create(loc, maxCmp, maxVal, value); + + if (isUnsigned) { + value = b.create(loc, destTy, value); + } else { + value = b.create(loc, destTy, value); + } + + return value; + } + op->emitError("unimplemented lowering in " "createLinalgPayloadCalculationForElementwiseOp"); return nullptr; @@ -1311,29 +1556,32 @@ class ConvertElementwiseOp : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if (!isa(op)) + if (!isa(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) @@ -1932,7 +2180,6 @@ class ConvertPrimsCollapseOp : public OpConversionPattern { associations.push_back(ReassociationIndices{i}); } - rewriter.replaceOpWithNewOp( op, resultRankedTensorType, adaptor.getA(), associations); @@ -1959,15 +2206,399 @@ class ConvertTensorStaticInfoCastOp }; } // namespace +namespace { +class ConvertLogitOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenLogitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Location loc = op->getLoc(); + Value input = adaptor.getSelf(); + Value eps = adaptor.getEps(); + + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + bool handleEps = false; + if (succeeded(checkNotNone(rewriter, op, eps))) + handleEps = true; + + if (handleEps && !eps.getType().isa()) { + op.emitError("Logit does not support non-floating point type"); + return failure(); + } + + auto inputType = input.getType().cast(); + auto inputElementType = inputType.getElementType(); + + if (!inputElementType.isa()) { + op.emitError("Logit does not support non-floating point type"); + return failure(); + } + + auto inputRank = inputType.getRank(); + + SmallVector indexingMaps = { + rewriter.getMultiDimIdentityMap(inputRank), // input + rewriter.getMultiDimIdentityMap(inputRank), // output + }; + SmallVector iteratorTypes( + inputRank, utils::IteratorType::parallel); + Value logit = + rewriter + .create( + loc, input.getType(), + /*ins=*/input, + /*outs=*/input, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value input = args[0]; + + TypedAttr oneAttr = b.getFloatAttr(inputElementType, 1.0); + Value oneValue = b.create(loc, oneAttr); + + Value zI; + if (!handleEps) { + zI = input; + } else { + Value truncEps = + b.create(loc, inputElementType, eps); + Value oneMinusEps = + b.create(loc, oneValue, truncEps); + + Value min = + b.create(loc, input, oneMinusEps); + Value clampedInput = + b.create(loc, min, truncEps); + + zI = clampedInput; + } + + Value probability = + b.create(loc, oneValue, zI); + Value odds = b.create(loc, zI, probability); + Value result = b.create(loc, odds); + + b.create(loc, result); + }) + .getResult(0); + Type newResultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, logit); + return success(); + } +}; +} // namespace + +namespace { +class ConvertAtenIntReprOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenIntReprOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType resultType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getSelf()); + return success(); + } +}; +} // namespace + +namespace { +class ConvertDequantizePerChannel + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenDequantizeSelfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto qoperand = op.getOperand(); + auto make = qoperand.getDefiningOp(); + if (!make) { + return rewriter.notifyMatchFailure(op, "did not find per channel qint"); + } + + auto converter = getTypeConverter(); + auto operand = make.getOperand(0); + auto scale = make.getScale(); + auto zeropoint = make.getZeroPoint(); + auto axis = make.getAxis(); + + IntegerAttr axisAttr; + if (!matchPattern(axis, m_Constant(&axisAttr))) { + return failure(); + } + + auto operandDTy = operand.getType().cast().getDtype(); + auto zeropointDTy = zeropoint.getType().cast().getDtype(); + operand = converter->materializeTargetConversion( + rewriter, loc, converter->convertType(operand.getType()), operand); + scale = converter->materializeTargetConversion( + rewriter, loc, converter->convertType(scale.getType()), scale); + zeropoint = converter->materializeTargetConversion( + rewriter, loc, converter->convertType(zeropoint.getType()), zeropoint); + + auto resultType = converter->convertType(op->getResult(0).getType()) + .cast(); + + llvm::SmallVector dynSizes; + for (auto [index, dim] : llvm::enumerate(resultType.getShape())) { + if (ShapedType::isDynamic(dim)) { + dynSizes.push_back(rewriter.create(loc, operand, index)); + } + } + + llvm::SmallVector iterators( + resultType.getRank(), utils::IteratorType::parallel); + llvm::SmallVector maps( + 4, {rewriter.getMultiDimIdentityMap(resultType.getRank())}); + auto broadcastMap = AffineMap::get( + resultType.getRank(), /*symbolCount=*/0, + {rewriter.getAffineDimExpr(axisAttr.getInt())}, rewriter.getContext()); + maps[1] = broadcastMap; + maps[2] = broadcastMap; + + auto empty = + rewriter.create(op.getLoc(), resultType, dynSizes); + auto linalgOp = rewriter.create( + loc, resultType, ValueRange{operand, scale, zeropoint}, + ValueRange{empty}, maps, iterators, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value operand = args[0]; + Value scale = args[1]; + Value zeropoint = args[2]; + if (operandDTy.isUnsignedInteger(8)) { + operand = b.create(loc, b.getI32Type(), operand); + } else if (operandDTy.isSignedInteger(8)) { + operand = b.create(loc, b.getI32Type(), operand); + } + + if (zeropointDTy.isUnsignedInteger(8)) { + zeropoint = + b.create(loc, b.getI32Type(), zeropoint); + } else if (zeropointDTy.isSignedInteger(8)) { + zeropoint = + b.create(loc, b.getI32Type(), zeropoint); + } + + Value sub = rewriter.create(loc, operand, zeropoint); + Value fp = + rewriter.create(loc, args[3].getType(), sub); + Value mul = rewriter.create(loc, fp, scale); + b.create(loc, mul); + }); + rewriter.replaceOp(op, linalgOp.getResults()); + return success(); + } +}; +} // namespace + +namespace { + +template +class ConvertCastEquivalentOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename OpTy::Adaptor; + + LogicalResult + matchAndRewrite(OpTy op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = this->getTypeConverter(); + RankedTensorType resultType = cast( + converter->convertType(op->getResult(0).getType())); + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getSelf()); + return success(); + } +}; +} // namespace + +namespace { +class ConvertAtenGridSamplerOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenGridSamplerOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + Type int64type = rewriter.getI64Type(); + Type floatType = rewriter.getF32Type(); + Value zeroIndex = rewriter.create(loc, 0); + Value oneIndex = rewriter.create(loc, 1); + Value twoIndex = rewriter.create(loc, 2); + Value zeroFloat = rewriter.create( + loc, rewriter.getFloatAttr(floatType, 0.0)); + Value oneFloat = rewriter.create( + loc, rewriter.getFloatAttr(floatType, 1.0)); + Value twoFloat = rewriter.create( + loc, rewriter.getFloatAttr(floatType, 2.0)); + Value input = adaptor.getInput(); + auto inputType = input.getType().cast(); + auto inputShape = inputType.getShape(); + Value innerDim0a = rewriter.create(loc, input, 2); + Value innerDim1a = rewriter.create(loc, input, 3); + Value innerDim0b = + rewriter.create(loc, innerDim0a, oneIndex); + Value innerDim1b = + rewriter.create(loc, innerDim1a, oneIndex); + Value innerDim0c = + rewriter.create(loc, int64type, innerDim0b); + Value innerDim1c = + rewriter.create(loc, int64type, innerDim1b); + Value innerDim0d = + rewriter.create(loc, floatType, innerDim0c); + Value innerDim1d = + rewriter.create(loc, floatType, innerDim1c); + Value innerDim0e = + rewriter.create(loc, innerDim0d, twoFloat); + Value innerDim1e = + rewriter.create(loc, innerDim1d, twoFloat); + Value grid = adaptor.getGrid(); + auto gridType = grid.getType().cast(); + auto gridShape = gridType.getShape(); + auto gridRank = gridType.getRank(); + SmallVector extractGridOffsets0(gridRank, zeroIndex); + SmallVector extractGridShape = getTensorSizes(rewriter, loc, grid); + SmallVector extractGridStride(gridRank, oneIndex); + int64_t lastGridDim = gridRank - 1; + extractGridShape[lastGridDim] = oneIndex; + extractGridStride[lastGridDim] = twoIndex; + SmallVector extractGridOffsets1(gridRank, zeroIndex); + extractGridOffsets1[lastGridDim] = oneIndex; + SmallVector gridShapeExtracted(gridShape); + gridShapeExtracted.back() = 1; + SmallVector gridShapeCollapsed{gridShape[0], gridShape[1], + gridShape[2]}; + auto grid0 = rewriter.create( + loc, grid, extractGridOffsets0, extractGridShape, extractGridStride); + auto grid1 = rewriter.create( + loc, grid, extractGridOffsets1, extractGridShape, extractGridStride); + SmallVector associations{ReassociationIndices{0}, + ReassociationIndices{1}, + ReassociationIndices{2, 3}}; + auto gridCollapsed0 = + rewriter.create(loc, grid0, associations); + auto gridCollapsed1 = + rewriter.create(loc, grid1, associations); + AffineMap gridMap = AffineMap::get(4, 0, + {rewriter.getAffineDimExpr(0), + rewriter.getAffineDimExpr(2), + rewriter.getAffineDimExpr(3)}, + op->getContext()); + SmallVector gridMaps{gridMap, gridMap, + rewriter.getMultiDimIdentityMap(gridRank)}; + SmallVector gridIterators( + gridRank, utils::IteratorType::parallel); + SmallVector resultShape{inputShape[0], inputShape[1], gridShape[1], + gridShape[2]}; + auto lambdaExtract = [](OpBuilder &b, Location loc, Value input, Value idxA, + Value idxB, Value idxC, Value idxD) -> Value { + SmallVector index{idxA, idxB, idxC, idxD}; + Value result = b.create(loc, input, index); + return result; + }; + auto lambdaInter = [&](OpBuilder &b, Location loc, Value x, Value y, + Value d) -> Value { + Value dm = b.create(loc, oneFloat, d); + Value ra = b.create(loc, x, dm); + Value rb = b.create(loc, y, d); + Value res = b.create(loc, ra, rb); + return res; + }; + auto resultType = getTypeConverter() + ->convertType(op.getResult().getType()) + .cast(); + SmallVector resultSize{}; + if (resultType.isDynamicDim(0)) + resultSize.push_back(rewriter.create(loc, input, 0)); + if (resultType.isDynamicDim(1)) + resultSize.push_back(rewriter.create(loc, input, 1)); + if (resultType.isDynamicDim(2)) + resultSize.push_back(rewriter.create(loc, grid, 1)); + if (resultType.isDynamicDim(3)) + resultSize.push_back(rewriter.create(loc, grid, 2)); + Value resultFinal = + rewriter.create(loc, resultType, resultSize); + auto sGrid = rewriter.create( + loc, TypeRange{resultType}, ValueRange{gridCollapsed0, gridCollapsed1}, + ValueRange(resultFinal), gridMaps, gridIterators, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value gr0 = args[1]; + Value gr1 = args[0]; + Value gplus0 = b.create(loc, gr0, oneFloat); + Value gplus1 = b.create(loc, gr1, oneFloat); + Value result0 = b.create(loc, gplus0, innerDim0e); + Value result1 = b.create(loc, gplus1, innerDim1e); + Value lower0 = b.create(loc, int64type, result0); + Value lower1 = b.create(loc, int64type, result1); + Value oneInt = + b.create(loc, b.getIntegerAttr(int64type, 1)); + Value upper0 = + b.create(loc, int64type, lower0, oneInt); + Value upper1 = + b.create(loc, int64type, lower1, oneInt); + Value notValid0 = rewriter.create( + loc, arith::CmpIPredicate::sgt, upper0, innerDim0c); + Value notValid1 = rewriter.create( + loc, arith::CmpIPredicate::sgt, upper1, innerDim1c); + Value upperValid0 = + b.create(loc, notValid0, lower0, upper0); + Value upperValid1 = + b.create(loc, notValid1, lower1, upper1); + Value lw0 = + b.create(loc, b.getIndexType(), lower0); + Value lw1 = + b.create(loc, b.getIndexType(), lower1); + Value up0 = + b.create(loc, b.getIndexType(), upperValid0); + Value up1 = + b.create(loc, b.getIndexType(), upperValid1); + Value N = b.create(loc, 0); + Value C = b.create(loc, 1); + Value result00 = lambdaExtract(b, loc, input, N, C, lw0, lw1); + Value result01 = lambdaExtract(b, loc, input, N, C, lw0, up1); + Value result01a = + b.create(loc, notValid1, zeroFloat, result01); + Value result10 = lambdaExtract(b, loc, input, N, C, up0, lw1); + Value result10a = + b.create(loc, notValid0, zeroFloat, result10); + Value result11 = lambdaExtract(b, loc, input, N, C, up0, up1); + Value result11a = + b.create(loc, notValid0, zeroFloat, result11); + Value result11b = + b.create(loc, notValid1, zeroFloat, result11a); + Value lw0a = b.create(loc, floatType, lower0); + Value lw1a = b.create(loc, floatType, lower1); + Value d1 = b.create(loc, result0, lw0a); + Value d0 = b.create(loc, result1, lw1a); + Value resultScaled0 = lambdaInter(b, loc, result00, result01a, d0); + Value resultScaled1 = lambdaInter(b, loc, result10a, result11b, d0); + Value resultScaled = + lambdaInter(b, loc, resultScaled0, resultScaled1, d1); + b.create(loc, resultScaled); + }); + rewriter.replaceOp(op, sGrid.getResults()); + return success(); + } +}; +} // namespace + void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); target.addIllegalOp< - AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp, - AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp, - AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenAtan2Op, - AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenClampTensorOp, + AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenAtanhOp, AtenAcoshOp, + AtenAsinOp, AtenAsinhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp, + AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp, + AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, + AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, @@ -1980,9 +2611,11 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp, - AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp, - AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, - AtenFillTensorOp, AtenRealOp, AtenImagOp>(); + AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp, + AtenRemainderScalarOp, AtenFmodTensorOp, AtenRemainderTensorOp, + AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, + AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp, + AtenQuantizePerTensorOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); @@ -1990,6 +2623,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); @@ -1998,4 +2633,15 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( patterns.add(typeConverter, context); patterns.add(typeConverter, context); target.addIllegalOp(); + patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add>( + typeConverter, context); + target.addIllegalOp(); + patterns.add>( + typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index ccc78985dc6c..366f5492aa6d 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -7,13 +7,13 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Tensor/Utils/Utils.h" #include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" @@ -70,7 +70,7 @@ Value torch_to_linalg::getZeroPaddedTensor( // padding value is zero. Value torch_to_linalg::getDynamicZeroPaddedTensor( Operation *op, OpBuilder &b, Value &input, SmallVectorImpl &padding, - int unpaddedDims) { + int unpaddedDims, Value pad) { assert(input.getType().isa() && "input must be RankedTensorType"); unsigned int inRank = input.getType().cast().getRank(); @@ -87,17 +87,16 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor( *pad = castIntToIndex(b, loc, *pad); Type elementType = input.getType().cast().getElementType(); + // TODO: audit possibility of sparsity on this tensor Type inputType = RankedTensorType::get(makeShapeLLVMCompatible(llvm::ArrayRef( SmallVector(inRank, kUnknownSize))), elementType); - Value cf0 = - b.create(loc, b.getFloatAttr(elementType, 0.0)); SmallVector paddingValues = getAsOpFoldResult(paddingIncludingUnchanged); return b.create(loc, inputType, input, /*low=*/paddingValues, - /*high=*/paddingValues, cf0); + /*high=*/paddingValues, pad); } Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc, @@ -198,7 +197,8 @@ Value torch_to_linalg::createReductionLinalgGeneric( } } - auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs}); + auto indexingMaps = + AffineMap::inferFromExprList({exprs, resultExprs}, b.getContext()); Value accumulator = createInitTensor(b, loc, resultShape, initElem.getType(), initElem); @@ -559,3 +559,20 @@ FailureOr torch_to_linalg::getBackendTypeForScalarType( } return type; } + +bool torch_to_linalg::isUnsignedTorchType(Type type) { + if (auto tty = dyn_cast(type)) + return isUnsignedTorchType(tty.getDtype()); + if (isa(type)) + return false; + if (isa(type)) + return false; + if (isa(type)) + return true; + if (isa(type)) + return false; + if (auto intTy = dyn_cast(type)) + return intTy.isUnsigned(); + llvm_unreachable("Unknown type checked for signedness"); + return false; +} diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index f0dc4aaf2dfa..22743e6a9dee 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -11,6 +11,7 @@ #include "../PassDetail.h" #include "PopulatePatterns.h" +#include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" @@ -377,12 +378,12 @@ class ConvertAtenAddSubOp : public OpConversionPattern { if (!skipMultiplyAlpha(op.getAlpha())) { Value alpha = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getAlpha(), outElemTy); - DenseIntElementsAttr bcastDimensions; + DenseI64ArrayAttr bcastDimensions; rhs = rewriter.create(op->getLoc(), rhs, alpha, bcastDimensions); } - DenseIntElementsAttr bcastDimensions; + DenseI64ArrayAttr bcastDimensions; rewriter.replaceOpWithNewOp(op, outType, lhs, rhs, bcastDimensions); return success(); @@ -424,7 +425,7 @@ class ConvertAtenMulDivOp : public OpConversionPattern { rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), outElemTy); } - DenseIntElementsAttr bcastDimensions; + DenseI64ArrayAttr bcastDimensions; lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); auto loc = op.getLoc(); @@ -542,7 +543,7 @@ class ConvertAtenCompareOp : public OpConversionPattern { } else { return op.emitError("operator haven't been supported"); } - DenseIntElementsAttr bcastDimensions; + DenseI64ArrayAttr bcastDimensions; rewriter.replaceOpWithNewOp( op, outType, lhs, rhs, bcastDimensions, compareDirectionAttr, compareTypeAttr); @@ -570,7 +571,7 @@ class ConvertAtenLogicalBinaryOp : public OpConversionPattern { Value rhs = hlo::promoteType(rewriter, op.getLoc(), adaptor.getOther(), outType); - DenseIntElementsAttr bcastDimensions; + DenseI64ArrayAttr bcastDimensions; rewriter.replaceOpWithNewOp(op, outType, lhs, rhs, bcastDimensions); return success(); @@ -757,7 +758,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( llvm::to_vector<4>(llvm::seq(leadingRank, totalRank)); rewriter.replaceOpWithNewOp( op, outType, self, bcastShapeTensor, - rewriter.getI64TensorAttr(dimensionNumbers)); + rewriter.getDenseI64ArrayAttr(dimensionNumbers)); } return success(); } @@ -887,7 +888,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!rhsType) { rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy); } - DenseIntElementsAttr bcastDimensions; + DenseI64ArrayAttr bcastDimensions; lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); auto loc = op.getLoc(); @@ -923,8 +924,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op.getA().getType().template cast().getDtype(); Type resultType = this->getTypeConverter()->convertType(op->getResult(0).getType()); - auto result = - rewriter.create(loc, adaptor.getA()); + auto result = rewriter.create(loc, adaptor.getA()); rewriter.replaceOp( op, convertScalarToDtype(rewriter, loc, result, resultType, inputDtype)); @@ -1479,7 +1479,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value window = rewriter.create(loc, outType, resultLength, 0); - DenseIntElementsAttr broadcastDimensions; + DenseI64ArrayAttr broadcastDimensions; Value mulOut = rewriter.create(loc, window, step, broadcastDimensions); rewriter.replaceOpWithNewOp(op, mulOut, start, @@ -1663,19 +1663,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) return rewriter.notifyMatchFailure( op, "unimplemented: dtype must be a constant integer or none"); - FailureOr maybeResultElementType = getTypeForScalarType( - op->getContext(), (torch_upstream::ScalarType)dtypeInt); + FailureOr maybeResultElementType = + torch_to_stablehlo::getBackendTypeForScalarType( + op->getContext(), (torch_upstream::ScalarType)dtypeInt); if (failed(maybeResultElementType)) { return rewriter.notifyMatchFailure( op, "unable to convert `dtypeInt` to builtin type"); } resultElementType = *maybeResultElementType; - // The stablehlo backend expects signed integers to be signless. - if (resultElementType.isSignedInteger()) { - resultElementType = IntegerType::get( - op->getContext(), resultElementType.getIntOrFloatBitWidth(), - IntegerType::Signless); - } } // Create an uninitialized tensor of `resultSize` shape. @@ -1722,7 +1717,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.create(op->getLoc(), adaptor.getSelf()); Value bcastScalar = rewriter.create( op->getLoc(), outType, scalarTensor, shapeTensor, - rewriter.getI64TensorAttr({})); + rewriter.getDenseI64ArrayAttr({})); rewriter.replaceOp(op, bcastScalar); return success(); } @@ -1764,7 +1759,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( #define INSERT_UNARY_PATTERN(AtenOp, StablehloOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context) - INSERT_UNARY_PATTERN(AtenCloneOp, stablehlo::ConvertOp); INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp); INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp); INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp); @@ -1785,6 +1779,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_UNARY_FPONLY_PATTERN(AtenCosOp, stablehlo::CosineOp); INSERT_UNARY_FPONLY_PATTERN(AtenCeilOp, stablehlo::CeilOp); INSERT_UNARY_FPONLY_PATTERN(AtenFloorOp, stablehlo::FloorOp); + INSERT_UNARY_FPONLY_PATTERN(AtenRoundOp, stablehlo::RoundNearestEvenOp); #undef INSERT_UNARY_FPONLY_PATTERN #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ @@ -1797,8 +1792,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( #define INSERT_TENSOR_TO_SCALAR_PATTERN(AtenOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, \ - context) + patterns.add>(typeConverter, context) INSERT_TENSOR_TO_SCALAR_PATTERN(AtenIntTensorOp); INSERT_TENSOR_TO_SCALAR_PATTERN(AtenFloatTensorOp); diff --git a/lib/Conversion/TorchToStablehlo/CMakeLists.txt b/lib/Conversion/TorchToStablehlo/CMakeLists.txt index 07ef1e2ea661..566f1d15b6ad 100644 --- a/lib/Conversion/TorchToStablehlo/CMakeLists.txt +++ b/lib/Conversion/TorchToStablehlo/CMakeLists.txt @@ -7,6 +7,7 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo ViewLike.cpp Reduction.cpp Pooling.cpp + Utils.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToStablehlo diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp index 9c8123bfdbad..53c418da4fb9 100644 --- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -30,8 +30,8 @@ using namespace mlir::torch::torch_to_stablehlo; namespace { static Value createInitialValueForGatherScatterOp(Operation *op, - RankedTensorType constType, - PatternRewriter &rewriter) { + RankedTensorType constType, + PatternRewriter &rewriter) { auto elementTy = constType.getElementType(); if (isa(op)) { if (elementTy.isa()) { @@ -334,7 +334,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return failure(); auto stablehloReduceOp = rewriter.create( - op.getLoc(), gatherOutput, initValue, rewriter.getI64TensorAttr({0})); + op.getLoc(), gatherOutput, initValue, rewriter.getDenseI64ArrayAttr({0}), + elementTy); Region ®ion = stablehloReduceOp.getBody(); Block &block = region.emplaceBlock(); @@ -510,7 +511,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, input, gatherIndicies, dimsAttr, - rewriter.getI64TensorAttr(sliceSizes)); + rewriter.getDenseI64ArrayAttr(sliceSizes)); return success(); } @@ -666,7 +667,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( /*indexVectorDim=*/indexVecDim); auto stablehloScatterOp = rewriter.create( - loc, input, scatterIndicies, src, scatterDimensionNumbers, false, false); + loc, inputType, input, scatterIndicies, src, scatterDimensionNumbers, + false, false); // config update computation function: just return the element from src. Block &block = stablehloScatterOp.getUpdateComputation().emplaceBlock(); @@ -833,7 +835,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, resultType, input, finalIndexTensor, dimsAttr, - rewriter.getI64TensorAttr(sliceSizes)); + rewriter.getDenseI64ArrayAttr(sliceSizes)); return success(); } diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp index df92317824a1..b1749ee1c074 100644 --- a/lib/Conversion/TorchToStablehlo/Linear.cpp +++ b/lib/Conversion/TorchToStablehlo/Linear.cpp @@ -39,10 +39,7 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor, RankedTensorType outTy = RankedTensorType::get(shape, tensorTy.getElementType()); - RankedTensorType attrTy = - RankedTensorType::get({static_cast(broadcastDims.size())}, - rewriter.getIntegerType(64)); - auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims); + auto broadcastAttr = rewriter.getDenseI64ArrayAttr(broadcastDims); auto broadcast = rewriter.create( loc, outTy, tensor, stablehloShape, broadcastAttr); @@ -549,8 +546,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { // Prepare for transposed convolution SmallVector stablehloStrideVec(nSpatialDims, 1); - DenseIntElementsAttr stablehloStride = - rewriter.getI64TensorAttr(stablehloStrideVec); + auto stablehloStride = rewriter.getDenseI64ArrayAttr(stablehloStrideVec); SmallVector stablehloPaddingVec(nSpatialDims * 2, 0); for (int i = 0; i < nSpatialDims; ++i) { int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i]; @@ -563,15 +559,15 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { stablehloPaddingVec); SmallVector stablehloLhsDilationVec(nSpatialDims); std::copy(stride.begin(), stride.end(), stablehloLhsDilationVec.begin()); - DenseIntElementsAttr stablehloLhsDilation = - rewriter.getI64TensorAttr(stablehloLhsDilationVec); + auto stablehloLhsDilation = + rewriter.getDenseI64ArrayAttr(stablehloLhsDilationVec); SmallVector stablehloRhsDilationVec(nSpatialDims); std::copy(dilation.begin(), dilation.end(), stablehloRhsDilationVec.begin()); - DenseIntElementsAttr stablehloRhsDilation = - rewriter.getI64TensorAttr(stablehloRhsDilationVec); + auto stablehloRhsDilation = + rewriter.getDenseI64ArrayAttr(stablehloRhsDilationVec); - DenseElementsAttr windowReversal; + DenseBoolArrayAttr windowReversal; ArrayAttr precisionConfig; SmallVector spatialDims; @@ -614,10 +610,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { int64_t nDims = outType.getRank(); // Get stablehlo::ConvolutionOp attributes - DenseIntElementsAttr stablehloWindowStride = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stride.size())}, - rewriter.getI64Type()), - stride); + auto stablehloWindowStride = rewriter.getDenseI64ArrayAttr(stride); std::vector stablehloPaddingVec; for (size_t i = 0; i < padding.size(); i++) { stablehloPaddingVec.emplace_back(padding[i]); @@ -628,10 +621,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { {static_cast(padding.size()), static_cast(2)}, rewriter.getI64Type()), stablehloPaddingVec); - DenseIntElementsAttr stablehloRhsDilation = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(dilation.size())}, - rewriter.getI64Type()), - dilation); + auto stablehloRhsDilation = rewriter.getDenseI64ArrayAttr(dilation); SmallVector spatialDimensions; for (int64_t i = 2; i < nDims; i++) { spatialDimensions.emplace_back(i); @@ -648,8 +638,8 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { /*outputSpatialDimensions=*/spatialDimensions); // stablehlo::ConvolutionOp's optional attributes, leave them as default - DenseIntElementsAttr stablehloLhsDilation; - DenseElementsAttr windowReversal; + DenseI64ArrayAttr stablehloLhsDilation; + DenseBoolArrayAttr windowReversal; ArrayAttr precisionConfig; auto stablehloConvOp = rewriter.create( @@ -781,7 +771,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { options.dimSizeIndexBits); bias = hlo::promoteType(rewriter, op.getLoc(), bias, outTy); - DenseIntElementsAttr bcastDimensions; + DenseI64ArrayAttr bcastDimensions; rewriter.replaceOpWithNewOp( op, outTy, stablehloConvResult, bias, bcastDimensions); return success(); diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index 7c28a2fd3004..40b0dd691071 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -35,7 +35,8 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { auto constType = RankedTensorType::get({}, elementTy); // Avg pooling - if (isa(op)) { + if (isa(op)) { if (elementTy.isa()) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getZero( @@ -135,19 +136,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( stablehloPadding[stablehloPadding.size() - 2] = padding[1]; stablehloPadding[stablehloPadding.size() - 1] = padding[1]; - DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloKernelSize.size())}, - rewriter.getI64Type()), - stablehloKernelSize); - DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloStride.size())}, - rewriter.getI64Type()), - stablehloStride); - DenseIntElementsAttr baseDilations; - DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloDilation.size())}, - rewriter.getI64Type()), - stablehloDilation); + auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize); + auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride); + DenseI64ArrayAttr baseDilations; + auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, @@ -241,19 +233,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( stablehloPadding[stablehloPadding.size() - 2] = padding[1]; stablehloPadding[stablehloPadding.size() - 1] = padding[1]; - DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloKernelSize.size())}, - rewriter.getI64Type()), - stablehloKernelSize); - DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloStride.size())}, - rewriter.getI64Type()), - stablehloStride); - DenseIntElementsAttr baseDilations; - DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloDilation.size())}, - rewriter.getI64Type()), - stablehloDilation); + auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize); + auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride); + DenseI64ArrayAttr baseDilations; + auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, @@ -373,7 +356,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } - namespace { template class ConvertAtenAvgPoolOp : public ConvertAtenOp { @@ -388,45 +370,45 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { Type inputElemTy = inputTy.getElementType(); int64_t inputRank = inputTy.getRank(); RankedTensorType outTy = ConvertAtenOp::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + ->convertType(op.getType()) + .template cast(); auto outShape = outTy.getShape(); - if (inputRank <= Dim) { - return op.emitError( - "avg_pooling1d/2d only supports inputs with rank higher than 1/2"); + return op.emitError( + "avg_pooling1d/2d only supports inputs with rank higher than 1/2"); } SmallVector padding, kernelSize, stride; bool ceilMode = false; bool countIncludePad = true; if (!(matchPattern(op.getKernelSize(), - m_TorchListOfConstantInts(kernelSize)))) { - return rewriter.notifyMatchFailure( - op, "non-const int kernel size unsupported!"); + m_TorchListOfConstantInts(kernelSize)))) { + return rewriter.notifyMatchFailure( + op, "non-const int kernel size unsupported!"); } if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) { - return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!"); + return rewriter.notifyMatchFailure(op, + "non-const int stride unsupported!"); } if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) { - return rewriter.notifyMatchFailure(op, - "non-const int padding unsupported!"); + return rewriter.notifyMatchFailure(op, + "non-const int padding unsupported!"); } if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) { - return rewriter.notifyMatchFailure(op, - "non-const bool ceil_mode unsupported!"); + return rewriter.notifyMatchFailure( + op, "non-const bool ceil_mode unsupported!"); } if (!(matchPattern(op.getCountIncludePad(), - m_TorchConstantBool(&countIncludePad)))) { - return rewriter.notifyMatchFailure( - op, "non-const bool count_include_pad unsupported!"); + m_TorchConstantBool(&countIncludePad)))) { + return rewriter.notifyMatchFailure( + op, "non-const bool count_include_pad unsupported!"); } if constexpr (std::is_same()) { - if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride()))) - return rewriter.notifyMatchFailure( - op, "only None divisor_override supported for now!"); + if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride()))) + return rewriter.notifyMatchFailure( + op, "only None divisor_override supported for now!"); } // Prepend 1 to kernelSize, stride, dilation until they are of same rank @@ -437,34 +419,26 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { SmallVector stablehloPadding(inputRank * 2, 0); std::copy(stride.begin(), stride.end(), - stablehloStride.begin() + inputRank - Dim); + stablehloStride.begin() + inputRank - Dim); std::copy(kernelSize.begin(), kernelSize.end(), - stablehloKernelSize.begin() + inputRank - Dim); + stablehloKernelSize.begin() + inputRank - Dim); if (Dim == 1) { - stablehloPadding[stablehloPadding.size() - 2] = padding[0]; - stablehloPadding[stablehloPadding.size() - 1] = padding[0]; + stablehloPadding[stablehloPadding.size() - 2] = padding[0]; + stablehloPadding[stablehloPadding.size() - 1] = padding[0]; } else { - stablehloPadding[stablehloPadding.size() - 4] = padding[0]; - stablehloPadding[stablehloPadding.size() - 3] = padding[0]; - stablehloPadding[stablehloPadding.size() - 2] = padding[1]; - stablehloPadding[stablehloPadding.size() - 1] = padding[1]; + stablehloPadding[stablehloPadding.size() - 4] = padding[0]; + stablehloPadding[stablehloPadding.size() - 3] = padding[0]; + stablehloPadding[stablehloPadding.size() - 2] = padding[1]; + stablehloPadding[stablehloPadding.size() - 1] = padding[1]; } - Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); - - DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloKernelSize.size())}, - rewriter.getI64Type()), - stablehloKernelSize); - DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloStride.size())}, - rewriter.getI64Type()), - stablehloStride); - DenseIntElementsAttr baseDilations; - DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloDilation.size())}, - rewriter.getI64Type()), - stablehloDilation); + Value initVal = + createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); + + auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize); + auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride); + DenseI64ArrayAttr baseDilations; + auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, @@ -485,31 +459,31 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { auto secondArg = *sumBlock.args_rbegin(); { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&sumBlock); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&sumBlock); - Value sumResult = - rewriter.create(op->getLoc(), firstArg, secondArg); - rewriter.create(op->getLoc(), sumResult); + Value sumResult = + rewriter.create(op->getLoc(), firstArg, secondArg); + rewriter.create(op->getLoc(), sumResult); } // Use kernel size as the divisor if (countIncludePad) { - Value divisor; - if (Dim == 1) { - divisor = - hlo::getConstTensor(rewriter, op, {kernelSize[0]}, {}) - .value(); - } else { - divisor = hlo::getConstTensor( - rewriter, op, {kernelSize[0] * kernelSize[1]}, {}) - .value(); - } - divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy); - DenseIntElementsAttr bcastDimensions; - rewriter.replaceOpWithNewOp( - op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions); - return success(); + Value divisor; + if (Dim == 1) { + divisor = + hlo::getConstTensor(rewriter, op, {kernelSize[0]}, {}) + .value(); + } else { + divisor = hlo::getConstTensor( + rewriter, op, {kernelSize[0] * kernelSize[1]}, {}) + .value(); + } + divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy); + DenseI64ArrayAttr bcastDimensions; + rewriter.replaceOpWithNewOp( + op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions); + return success(); } // Use another mhlo.ReduceWindowOp to get the divisor @@ -518,15 +492,15 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { windowSizeConst = hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy); const auto &options = ConvertAtenOp::getOptions(); - auto inputShapeVec = - *hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + auto inputShapeVec = *hlo::getDimSizesOfTensor(rewriter, op, input, + options.dimSizeIndexBits); auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); windowSizeConst = rewriter.create( op->getLoc(), RankedTensorType::get(inputTy.getShape(), outTy.getElementType()), - windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({})); + windowSizeConst, inputShapeTensor, rewriter.getDenseI64ArrayAttr({})); Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); auto reduceWindowSize = rewriter.create( @@ -544,23 +518,20 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { secondArg = *sizeBlock.args_rbegin(); { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&sizeBlock); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&sizeBlock); - Value sumResult = - rewriter.create(op->getLoc(), firstArg, secondArg); - rewriter.create(op->getLoc(), sumResult); + Value sumResult = + rewriter.create(op->getLoc(), firstArg, secondArg); + rewriter.create(op->getLoc(), sumResult); } rewriter.replaceOpWithNewOp( op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0)); return success(); - } - }; -} - +} // namespace // AtenCumsumOp template <> @@ -569,11 +540,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputTy = input.getType().cast(); + auto outTy = + getTypeConverter()->convertType(op.getType()).cast(); + input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); + inputTy = input.getType().cast(); auto inputElemTy = inputTy.getElementType(); auto inputRank = inputTy.getRank(); auto inputShape = inputTy.getShape(); - auto outTy = - getTypeConverter()->convertType(op.getType()).cast(); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { @@ -598,19 +571,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( SmallVector stablehloPadding(inputRank * 2, 0); stablehloPadding[dim * 2] = inputShape[dim] - 1; - DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloKernelSize.size())}, - rewriter.getI64Type()), - stablehloKernelSize); - DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloStride.size())}, - rewriter.getI64Type()), - stablehloStride); - DenseIntElementsAttr baseDilations; - DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloDilation.size())}, - rewriter.getI64Type()), - stablehloDilation); + auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize); + auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride); + DenseI64ArrayAttr baseDilations; + auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, @@ -658,10 +622,10 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality( context, options); target.addIllegalOp(); patterns.add>(typeConverter, context, options); -#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \ +#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \ target.addIllegalOp(); \ - patterns.add>( \ - typeConverter, context, options) + patterns.add>(typeConverter, context, \ + options) INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1); INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2); #undef INSERT_ATEN_AVGPOOL_PATTERN diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index 36f4d49e9a99..0b27d0748855 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -16,13 +16,16 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" +#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" -#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" + +#include +#include using namespace mlir; using namespace mlir::torch; @@ -116,6 +119,12 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, initIndex = hlo::getConstTensor(rewriter, op, {0}, {}).value(); } + std::vector outputShape(inputShape.begin(), inputShape.end()); + outputShape.erase(outputShape.begin() + dim); + auto outputTy = RankedTensorType::get(outputShape, inputElemTy); + auto outputIndexTy = + RankedTensorType::get(outputShape, rewriter.getIntegerType(64)); + auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); auto indexTensor = rewriter.create( @@ -125,12 +134,13 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, inputShapeTensor, static_cast(dim)); auto stablehloReduceOp = rewriter.create( - op->getLoc(), ValueRange{input, indexTensor}, + op->getLoc(), TypeRange{outputTy, outputIndexTy}, + ValueRange{input, indexTensor}, ValueRange{ initValue, initIndex, }, - rewriter.getI64TensorAttr(dim)); + rewriter.getDenseI64ArrayAttr(dim)); Block &block = stablehloReduceOp.getBody().emplaceBlock(); @@ -412,7 +422,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( - op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); + op.getLoc(), RankedTensorType::get({}, outTy.getElementType()), input, + initValue, rewriter.getDenseI64ArrayAttr(dims)); Block &block = stablehloReduceOp.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); @@ -473,7 +484,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( return failure(); llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( - op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); + op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue, + rewriter.getDenseI64ArrayAttr(dims)); Block &block = stablehloReduceOp.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); @@ -535,7 +547,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( return failure(); llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( - op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); + op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue, + rewriter.getDenseI64ArrayAttr(dims)); Block &block = stablehloReduceOp.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); @@ -614,6 +627,14 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } } + std::unordered_set dimsSet(dims.begin(), dims.end()); + SmallVector reduceResultShape; + for (int64_t i = 0; i < inputTy.getRank(); i++) { + if (dimsSet.find(i) == dimsSet.end()) { + reduceResultShape.push_back(inputTy.getDimSize(i)); + } + } + bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); @@ -625,7 +646,9 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( - op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); + op.getLoc(), + RankedTensorType::get(reduceResultShape, outTy.getElementType()), input, + initValue, rewriter.getDenseI64ArrayAttr(dims)); Region ®ion = stablehloReduceOp.getBody(); Block &block = region.emplaceBlock(); @@ -714,6 +737,14 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( // stable with unordered dims. std::sort(dims.begin(), dims.end()); + std::unordered_set dimsSet(dims.begin(), dims.end()); + SmallVector reduceResultShape; + for (int64_t i = 0; i < inputRank; i++) { + if (dimsSet.find(i) == dimsSet.end()) { + reduceResultShape.push_back(inputType.getDimSize(i)); + } + } + bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { return rewriter.notifyMatchFailure( @@ -728,8 +759,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } auto reduceOp = rewriter.create( - op->getLoc(), squareOp.getResult(), initValue, - rewriter.getI64TensorAttr(dims)); + op->getLoc(), RankedTensorType::get(reduceResultShape, inputElemType), + squareOp.getResult(), initValue, rewriter.getDenseI64ArrayAttr(dims)); Region ®ion = reduceOp.getBody(); Block &block = region.emplaceBlock(); @@ -832,6 +863,14 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( std::sort(dims.begin(), dims.end()); } + std::unordered_set dimsSet(dims.begin(), dims.end()); + SmallVector reduceResultShape; + for (int64_t i = 0; i < inputType.getRank(); i++) { + if (dimsSet.find(i) == dimsSet.end()) { + reduceResultShape.push_back(inputType.getDimSize(i)); + } + } + bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { return rewriter.notifyMatchFailure( @@ -848,7 +887,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( ord, nullptr); auto reduceOp = rewriter.create( - op->getLoc(), powValue, initValue, rewriter.getI64TensorAttr(dims)); + op->getLoc(), RankedTensorType::get(reduceResultShape, outElemType), + powValue, initValue, rewriter.getDenseI64ArrayAttr(dims)); Region ®ion = reduceOp.getBody(); Block &block = region.emplaceBlock(); diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp index ed203cb0f91f..c3f8eff22fbc 100644 --- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp +++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp @@ -241,10 +241,7 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, if (!do_bcast) { return input; } - DenseIntElementsAttr bcast_attr = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(bcastDims.size())}, - rewriter.getI64Type()), - bcastDims); + auto bcast_attr = rewriter.getDenseI64ArrayAttr(bcastDims); auto bcast_op = rewriter.create( op->getLoc(), outType, input, bcast_attr); return bcast_op.getResult(); @@ -360,7 +357,7 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc, auto constTensor = rewriter.create(loc, constAttr); return rewriter .create( - loc, outType, constTensor, shape, rewriter.getI64TensorAttr({})) + loc, outType, constTensor, shape, rewriter.getDenseI64ArrayAttr({})) .getResult(); } } // namespace hlo diff --git a/lib/Conversion/TorchToStablehlo/Utils.cpp b/lib/Conversion/TorchToStablehlo/Utils.cpp new file mode 100644 index 000000000000..390888750110 --- /dev/null +++ b/lib/Conversion/TorchToStablehlo/Utils.cpp @@ -0,0 +1,30 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "./Utils.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" + +using namespace mlir; +using namespace torch; + +FailureOr torch_to_stablehlo::getBackendTypeForScalarType( + MLIRContext *context, torch_upstream::ScalarType dtypeInt) { + FailureOr maybeType = Torch::getTypeForScalarType( + context, (torch_upstream::ScalarType)dtypeInt); + if (failed(maybeType)) { + return failure(); + } + Type type = *maybeType; + // The stablehlo backend expects signed integers to be signless. + if (type.isSignedInteger()) { + type = IntegerType::get(context, type.getIntOrFloatBitWidth(), + IntegerType::Signless); + } + return type; +} diff --git a/lib/Conversion/TorchToStablehlo/Utils.h b/lib/Conversion/TorchToStablehlo/Utils.h new file mode 100644 index 000000000000..16788e3955c4 --- /dev/null +++ b/lib/Conversion/TorchToStablehlo/Utils.h @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" + +namespace mlir { +namespace torch { +namespace torch_to_stablehlo { + +// Convert a scalar type to the corresponding builtin type in the +// stablehlo backend. +FailureOr +getBackendTypeForScalarType(MLIRContext *context, + torch_upstream::ScalarType dtypeInt); + +} // namespace torch_to_stablehlo +} // namespace torch +} // namespace mlir diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp index ea19092e6c8b..507821dee638 100644 --- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp +++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/StablehloOps.h" +#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -22,7 +23,6 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" -#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include using namespace mlir; @@ -403,7 +403,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return op->emitError("dim must be a Scalar constant"); - int64_t inputRank = adaptor.getSelf().getType().cast().getRank(); + int64_t inputRank = + adaptor.getSelf().getType().cast().getRank(); dim = toPositiveDim(dim, inputRank + 1); if (!isValidDim(dim, inputRank + 1)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index d11a5524af7d..ebbb0b5f362b 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -200,15 +200,30 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter, scatterInputsVector[indexType.getRank()]); } +static llvm::SmallVector createDefaultDimMap(Value indices) { + llvm::SmallVector dmap; + if (auto iTy = dyn_cast(indices.getType())) + dmap.resize(iTy.getSizes()[1]); + + if (auto iTy = dyn_cast(indices.getType())) + dmap.resize(iTy.getDimSize(1)); + + for (int i = 0, s = dmap.size(); i < s; ++i) + dmap[i] = i; + + return dmap; +} + static Value createTMTensorScatterOp( OpBuilder &b, Location loc, Value updates, Value indices, Value original, - bool uniqueIndices, + llvm::ArrayRef dimensionsMap, bool uniqueIndices, function_ref bodyBuild) { + auto dimensionsMapAttr = b.getDenseI64ArrayAttr(dimensionsMap); auto originalTensorType = original.getType().cast(); Type originalElementType = originalTensorType.getElementType(); auto scatterOp = b.create( loc, originalTensorType, ValueRange{updates, indices}, - ValueRange{original}, uniqueIndices); + ValueRange{original}, dimensionsMapAttr, uniqueIndices); Region &scatterOpRegion = scatterOp.getRegion(); auto &scatterOpBlock = scatterOpRegion.emplaceBlock(); @@ -334,7 +349,7 @@ class ConvertAtenScatterSrcOp : public OpConversionPattern { src, dim); Value scatterOp = createTMTensorScatterOp( rewriter, loc, updates, indices, self, - /*uniqueIndices=*/false, + /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value updatesElement, Value inputElement) { b.create(loc, updatesElement); @@ -455,7 +470,7 @@ class ConvertAtenBincountOp : public OpConversionPattern { Value scatterOp = createTMTensorScatterOp( rewriter, loc, updatesTensor, indices, bincountTensor, - /*uniqueIndices=*/false, + /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value _, Value bincountElem) { Value add = b.create(loc, bincountElem, constantOne); b.create(loc, add); @@ -466,235 +481,200 @@ class ConvertAtenBincountOp : public OpConversionPattern { }; } // namespace -// """Create a map from each dimension of the input tensor to the -// subspace that dimension corresponds to in the result shape one gets -// from indexing the tensor with the optional index tensors. -// -// Note: Index tensors are first broadcasted to a common shape before -// creating the mapping. So the index of every index tensor will map to -// the same dimensions in the result shape. -// -// For example: -// indices = [None, None, torch.randint(4, (6, 1)), torch.randint(5, (7,))] -// indexBroadcastShapeValue = [6, 7] -// map = {0: [0], 1: [1], 2: [2, 3], 3: [2, 3]} -static SmallVector> -getInputShapeToOutputShapeMap(SmallVector optionalIndices, - SmallVector indexBroadcastShapeValue) { - SmallVector indices; - for (Value index : optionalIndices) { - if (!index.getType().isa()) - indices.push_back(index); +namespace { + +Value combinePutIndices(Location loc, llvm::ArrayRef indicesRef, + OpBuilder b) { + llvm::SmallVector indices(indicesRef); + // Declare commonly used constants up front: + Value torchCstZero = + b.create(loc, b.getI64IntegerAttr(0)); + Value torchCstOne = + b.create(loc, b.getI64IntegerAttr(1)); + Value torchCstNegOne = + b.create(loc, b.getI64IntegerAttr(-1)); + + // Determine the broadcast sizes and materialize missing implicit end + // dimensions: + int64_t indicesRank = 0; + for (auto index : indices) { + auto indexTy = cast(index.getType()); + int64_t rank = indexTy.getSizes().size(); + indicesRank = std::max(rank, indicesRank); } - unsigned broadcastRank = indexBroadcastShapeValue.size(); - unsigned numIndexTensors = indices.size(); - int64_t indexOfFirstIndexTensor = -1; - SmallVector> result; - - for (unsigned i = 0; i < optionalIndices.size(); i++) { - if (optionalIndices[i].getType().isa()) { - unsigned val = i; - if (indexOfFirstIndexTensor >= 0) - val += broadcastRank - numIndexTensors; - result.push_back({val}); - } else { - if (indexOfFirstIndexTensor < 0) - indexOfFirstIndexTensor = i; - SmallVector outputIndices; - for (unsigned j = indexOfFirstIndexTensor; - j < (indexOfFirstIndexTensor + broadcastRank); j++) - outputIndices.push_back(j); - result.push_back(outputIndices); + auto maxDim = [](int64_t dim0, int64_t dim1) { + if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize) + return Torch::kUnknownSize; + return std::max(dim0, dim1); + }; + + llvm::SmallVector broadcastSizes(indicesRank, torchCstOne); + llvm::SmallVector broadcastShape(indicesRank, 0); + for (auto index : indices) { + auto indexTy = cast(index.getType()); + auto shape = indexTy.getSizes(); + int32_t rank = shape.size(); + + for (int32_t j = 0; j < rank; ++j) { + Value dim = b.create(loc, b.getI64IntegerAttr(j)); + auto sizeOp = b.create(loc, index, dim); + auto size = shape[j]; + + int32_t idx = broadcastShape.size() - rank + j; + broadcastSizes[idx] = + b.create(loc, sizeOp, broadcastSizes[idx]); + broadcastShape[idx] = maxDim(size, broadcastShape[idx]); } } - return result; -} -static std::tuple, SmallVector> -getIndicesFinalShape(ConversionPatternRewriter &rewriter, Location loc, - Value input, SmallVector optionalIndices, - SmallVector inputShapeInt, - SmallVector inputShapeValue, - SmallVector indexBroadcastShapeInt, - SmallVector indexBroadcastShapeValue) { - SmallVector result; - SmallVector resultInt; - bool handledIndexTensorSpace = false; - - for (unsigned i = 0; i < inputShapeValue.size(); i++) { - if (optionalIndices[i].getType().isa()) { - result.push_back(inputShapeValue[i]); - resultInt.push_back(inputShapeInt[i]); - } else { - if (!handledIndexTensorSpace) { - handledIndexTensorSpace = true; - for (unsigned j = 0; j < indexBroadcastShapeValue.size(); j++) { - result.push_back(indexBroadcastShapeValue[j]); - resultInt.push_back(indexBroadcastShapeInt[j]); - } - } - } + auto mulDim = [](int64_t dim0, int64_t dim1) { + if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize) + return Torch::kUnknownSize; + return dim0 * dim1; + }; + + int64_t scatterBatchCount = 1; + for (auto dim : broadcastShape) { + scatterBatchCount = mulDim(scatterBatchCount, dim); } - return std::make_tuple(result, resultInt); -} -static FailureOr -getScatterIndices(Aten_IndexPutImplOp op, ConversionPatternRewriter &rewriter, - Type indicesDtype, SmallVector optionalIndices, - SmallVector indexBroadcastShapeInt, - SmallVector indexBroadcastShapeValue) { - Location loc = op.getLoc(); - MLIRContext *context = op->getContext(); - Value input = op.getSelf(); - - SmallVector> shapeMap = - getInputShapeToOutputShapeMap(optionalIndices, indexBroadcastShapeValue); - - SmallVector inputShapeInt{ - input.getType().cast().getSizes()}; - int64_t inputRank = inputShapeInt.size(); - SmallVector inputShapeValue; - for (unsigned i = 0; i < inputShapeInt.size(); i++) { - Value dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); - inputShapeValue.push_back( - rewriter.createOrFold(loc, input, dim)); + // Broadcast together and flatten to batch values: + Value broadcastSizeList = b.create( + loc, Torch::ListType::get(b.getType()), broadcastSizes); + for (Value &index : indices) { + auto indexTy = cast(index.getType()); + auto expandTy = b.getType( + broadcastShape, indexTy.getOptionalDtype()); + index = b.create(loc, expandTy, index, + broadcastSizeList); + + auto flattenTy = b.getType( + scatterBatchCount, indexTy.getOptionalDtype()); + index = b.create( + loc, flattenTy, index, torchCstZero, torchCstNegOne); } - auto finalShapeResult = getIndicesFinalShape( - rewriter, loc, input, optionalIndices, inputShapeInt, inputShapeValue, - indexBroadcastShapeInt, indexBroadcastShapeValue); - SmallVector finalShapeValue = std::get<0>(finalShapeResult); - SmallVector finalShapeInt = std::get<1>(finalShapeResult); + // Unsqueeze so we have a 1 dim to concat along: + for (Value &tensor : indices) { + auto btt = cast(tensor.getType()); + if (!btt.hasSizes()) + return nullptr; - Value torchCstNone = rewriter.create(loc); - Value torchCstZero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); - Value torchCstOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - - Value indexBroadcastShapeTorchList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), - indexBroadcastShapeValue); - - // Calculating index count. - int64_t indexCount = 1; - if (llvm::all_of(finalShapeInt, - [](int64_t shape) { return shape != kUnknownSize; })) { - for (int64_t i : finalShapeInt) - indexCount *= i; - } else { - indexCount = kUnknownSize; + llvm::SmallVector shape(btt.getSizes()); + shape.push_back(1); + + auto unsqueezeTy = b.getType(shape, btt.getDtype()); + Value unsqueezed = + b.create(loc, unsqueezeTy, tensor, torchCstOne); + tensor = unsqueezed; } - Value indexCountValue = finalShapeValue[0]; - for (unsigned i = 1; i < finalShapeValue.size(); i++) - indexCountValue = - rewriter.create(loc, indexCountValue, finalShapeValue[i]); - - ValueTensorType flattenIndicesType = - ValueTensorType::get(context, llvm::ArrayRef(indexCount), indicesDtype); - Value flattenEndDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(finalShapeInt.size() - 1)); - - SmallVector broadcastedIndices; - for (unsigned i = 0; i < optionalIndices.size(); i++) { - Value broadcastedIndexTensor; - if (optionalIndices[i].getType().isa()) { - Value torchCstDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); - Value inputDim = rewriter.create(loc, input, torchCstDim); - ValueTensorType tensorType = ValueTensorType::get( - context, llvm::ArrayRef(inputShapeInt[i]), indicesDtype); - broadcastedIndexTensor = rewriter.create( - loc, tensorType, /*start=*/torchCstZero, /*end=*/inputDim, - /*step=*/torchCstOne, - /*dtype=*/torchCstNone, - /*layout=*/torchCstNone, - /*device=*/torchCstNone, - /*pin_memory=*/torchCstNone); - } else { - ValueTensorType tensorType = ValueTensorType::get( - context, llvm::ArrayRef(indexBroadcastShapeInt), indicesDtype); - broadcastedIndexTensor = rewriter.create( - loc, tensorType, optionalIndices[i], indexBroadcastShapeTorchList); - } + BaseTensorType unsqueezedTensorType = + indices[0].getType().cast(); + Value indicesTorchList = b.create( + loc, Torch::ListType::get(unsqueezedTensorType), indices); + llvm::SmallVector concatShape{ + unsqueezedTensorType.getSizes()[0], static_cast(indices.size())}; + ValueTensorType concatIndicesType = b.getType( + llvm::ArrayRef(concatShape), unsqueezedTensorType.getDtype()); + return b.create(loc, concatIndicesType, indicesTorchList, + torchCstOne); +} - // spotlight_indices(final_shape, shape_map[i]): - // Turn all values in `final_shape` to `1` except for those with index in - // `indices`. - // for j in range(len(final_shape)): - // if j not in indices: - // final_shape[j] = 1 - // This is equivalent to unsqueezing the index tensor at the dimension `j` - // not in indices. - for (unsigned j = 0; j < finalShapeInt.size(); j++) { - if (llvm::find(shapeMap[i], j) == shapeMap[i].end()) { - Value unsqueezeDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(j)); - auto unsqueezedInfo = - unsqueezeTensor(rewriter, op, broadcastedIndexTensor, - /*dim=*/unsqueezeDim); - if (failed(unsqueezedInfo)) { - return rewriter.notifyMatchFailure( - op, "cannot generate unsqueeze tensor op"); - } - broadcastedIndexTensor = *unsqueezedInfo; - } - } +// Helper that collapses the batch dimensions together and moves it to the front +// of the array. +static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch, + int64_t count, OpBuilder b) { + if (batch == 0 && count == 1) + return values; + + auto valuesTy = cast(values.getType()); + auto inShape = valuesTy.getSizes(); - // Performing broadcast to final shape. - Value broadcastShapeTorchList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), - finalShapeValue); - ValueTensorType broadcastTensorType = ValueTensorType::get( - context, llvm::ArrayRef(finalShapeInt), indicesDtype); - broadcastedIndexTensor = rewriter.create( - loc, broadcastTensorType, broadcastedIndexTensor, - broadcastShapeTorchList); - - // Flattening the tensor. - broadcastedIndexTensor = rewriter.create( - loc, flattenIndicesType, broadcastedIndexTensor, torchCstZero, - flattenEndDim); - - broadcastedIndices.push_back(broadcastedIndexTensor); + llvm::SmallVector outShape; + llvm::SmallVector outDims; + + // We need a length-1 dim at the start to transpose the batch to: + if (batch != 0) { + outDims.push_back(b.create(loc, 1)); + outShape.push_back(1); } - // Stacking broadcasted indices. - Value scatterIndices; - // The operation torch.stack([a, b], dim=0) is decomposed into: - // torch.cat([a.unsqueeze(dim=0), b.unsqueeze(dim=0)], dim=0) - // Unsqueeze all tensors before concatenating. - SmallVector unsqueezedIndexTensors; - for (Value tensor : broadcastedIndices) { - auto unsqueezedInfo = - unsqueezeTensor(rewriter, op, tensor, /*dim=*/torchCstZero); - if (failed(unsqueezedInfo)) { - return rewriter.notifyMatchFailure(op, - "cannot generate unsqueeze tensor op"); - } - unsqueezedIndexTensors.push_back(*unsqueezedInfo); + // Dimensions before the batch stay the same: + for (int i = 0; i <= batch; i++) { + auto k = b.create(loc, b.getI64IntegerAttr(i)); + auto dim = b.create(loc, values, k); + outDims.push_back(dim); + outShape.push_back(inShape[i]); } - BaseTensorType unsqueezedTensorType = - unsqueezedIndexTensors[0].getType().cast(); - Value concatIndicesTorchList = rewriter.create( - loc, Torch::ListType::get(unsqueezedTensorType), unsqueezedIndexTensors); - ValueTensorType concatIndicesType = ValueTensorType::get( - context, llvm::ArrayRef({inputRank, indexCount}), indicesDtype); - scatterIndices = rewriter.create( - loc, concatIndicesType, concatIndicesTorchList, torchCstZero); - - ValueTensorType transposedIndicesType = ValueTensorType::get( - context, llvm::ArrayRef({indexCount, inputRank}), indicesDtype); - scatterIndices = rewriter.create( - loc, transposedIndicesType, scatterIndices, torchCstZero, torchCstOne); - return scatterIndices; + auto mulI = [](int64_t dim0, int64_t dim1) { + if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize) + return Torch::kUnknownSize; + return dim0 * dim1; + }; + + // Determine the collapse size of the batch dimension: + for (int i = 1; i < count; i++) { + outShape.back() = mulI(outShape.back(), inShape[batch + i]); + + auto k = + b.create(loc, b.getI64IntegerAttr(batch + i)); + auto dim = b.create(loc, values, k); + outDims.back() = b.create(loc, dim, outDims.back()); + } + + // Add the dimensions after the batch dims: + for (int i = batch + count, s = inShape.size(); i < s; ++i) { + auto k = b.create(loc, b.getI64IntegerAttr(i)); + auto dim = b.create(loc, values, k); + outDims.push_back(dim); + outShape.push_back(inShape[i]); + } + + Value outDimsList = b.create( + loc, Torch::ListType::get(b.getType()), outDims); + + valuesTy = + b.getType(outShape, valuesTy.getOptionalDtype()); + values = b.create(loc, valuesTy, values, outDimsList); + + if (batch == 0) + return values; + + // Batch is already at the front, no need to transpose: + std::swap(outDims[0], outDims[batch + 1]); + std::swap(outShape[0], outShape[batch + 1]); + + Value dim0 = b.create(loc, b.getI64IntegerAttr(0)); + Value dimB = + b.create(loc, b.getI64IntegerAttr(batch + 1)); + + valuesTy = + b.getType(outShape, valuesTy.getOptionalDtype()); + values = + b.create(loc, valuesTy, values, dim0, dimB); + + outDims.clear(); + outShape.clear(); + auto transposeShape = valuesTy.getSizes(); + int64_t transposeRank = transposeShape.size(); + for (int i = 0; i < transposeRank; ++i) { + if (i == batch + 1) + continue; + Value k = b.create(loc, b.getI64IntegerAttr(i)); + outDims.push_back(b.create(loc, values, k)); + outShape.push_back(transposeShape[i]); + } + + valuesTy = + b.getType(outShape, valuesTy.getOptionalDtype()); + outDimsList = b.create( + loc, Torch::ListType::get(b.getType()), outDims); + return b.create(loc, valuesTy, values, outDimsList); } -namespace { class ConvertAten_IndexPutImplOp : public OpConversionPattern { public: @@ -706,11 +686,11 @@ class ConvertAten_IndexPutImplOp return failure(); Location loc = op.getLoc(); MLIRContext *context = op->getContext(); - Value input = adaptor.getSelf(); - Value values = adaptor.getValues(); - RankedTensorType inputType = input.getType().cast(); - RankedTensorType valuesType = values.getType().cast(); - int64_t inputRank = inputType.getRank(); + Value input = op.getSelf(); + Value values = op.getValues(); + auto inputType = cast(input.getType()); + auto valuesType = cast(values.getType()); + int64_t inputRank = inputType.getSizes().size(); auto valuesTensorType = op.getValues().getType().cast(); auto resultType = typeConverter->convertType(op->getResult(0).getType()) .cast(); @@ -737,190 +717,111 @@ class ConvertAten_IndexPutImplOp op, "Expected accumulate to be constant bool."); // The element type of the `input` and `values` should be same. - if (inputType.getElementType() != valuesType.getElementType()) + if (inputType.getDtype() != valuesType.getDtype()) return rewriter.notifyMatchFailure( op, "Input element type should be same as the values element type."); + if (valuesType.getSizes().empty()) + return rewriter.notifyMatchFailure( + op, "not implemented"); + SmallVector optionalIndicesList; getListConstructElements(op.getIndices(), optionalIndicesList); + int64_t optionalIndicesCount = optionalIndicesList.size(); // The size of the list of the index tensors should not be greater than the // input rank. - if ((int64_t)optionalIndicesList.size() > inputRank) + if (optionalIndicesCount > inputRank) return rewriter.notifyMatchFailure( op, "Indices list size should not be greater than the input rank."); - Value torchCstNone = rewriter.create(loc); - unsigned sizeOptionalIndicesList = optionalIndicesList.size(); - SmallVector nonNoneIndexTensorDim; - unsigned numNonNoneIndices; - - if (sizeOptionalIndicesList == 0) + if (optionalIndicesCount == 0) return rewriter.notifyMatchFailure(op, "Indices list must not be empty."); - for (unsigned i = 0; i < optionalIndicesList.size(); i++) { - if (!optionalIndicesList[i].getType().isa()) { - nonNoneIndexTensorDim.push_back(i); - } - } - - numNonNoneIndices = nonNoneIndexTensorDim.size(); - if (numNonNoneIndices > 2) { - return rewriter.notifyMatchFailure( - op, "unimplemented: non none index tensors less than or equal to 2 " - "supported only"); - } else if (numNonNoneIndices == 2 && - nonNoneIndexTensorDim[0] != nonNoneIndexTensorDim[1] - 1) { - return rewriter.notifyMatchFailure( - op, "unimplemented: case of 2 non none index tensors is supported " - "only when both the tensors are along consecutive dimensions"); - } - - // Padding the indices list with none values. - if (sizeOptionalIndicesList < inputRank) { - for (unsigned i = 0; i < (inputRank - sizeOptionalIndicesList); i++) - optionalIndicesList.push_back(torchCstNone); + // Filter to available indices and get the indicesMap: + SmallVector indicesList; + SmallVector indicesMap; + int64_t numBatchDims = 0; + for (int i = 0, s = optionalIndicesList.size(); i < s; ++i) { + if (isa(optionalIndicesList[i].getType())) + continue; + indicesList.push_back(optionalIndicesList[i]); + indicesMap.push_back(i); + + auto indexTy = cast(indicesList.back().getType()); + numBatchDims = std::max(static_cast(indexTy.getSizes().size()), + numBatchDims); } - SmallVector indexBroadcastShapeInt{ - optionalIndicesList[nonNoneIndexTensorDim[0]] - .getType() - .cast() - .getSizes()}; - SmallVector indexBroadcastShapeValue; - if (numNonNoneIndices == 2) { - computeBroadcastShape(rewriter, loc, - optionalIndicesList[nonNoneIndexTensorDim[0]], - optionalIndicesList[nonNoneIndexTensorDim[1]], - indexBroadcastShapeInt, indexBroadcastShapeValue); - } else { - // It means there's only one index tensor and broadcast shape is same as - // that index tensor' shape. - for (unsigned i = 0; i < indexBroadcastShapeInt.size(); i++) { - Value dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); - indexBroadcastShapeValue.push_back(rewriter.createOrFold( - loc, optionalIndicesList[nonNoneIndexTensorDim[0]], dim)); + // Value broadcasting semantics require batch dimensions to be up front if + // the indices are not sequential, otherwise they are sequentially at their + // location: + int64_t batchDim = 0; + for (int s = optionalIndicesList.size(); batchDim < s; ++batchDim) + if (!isa(optionalIndicesList[batchDim].getType())) + break; + + int64_t nextNone = batchDim; + for (int s = optionalIndicesList.size(); nextNone < s; ++nextNone) + if (isa(optionalIndicesList[nextNone].getType())) + break; + + for (int s = optionalIndicesList.size(); nextNone < s; ++nextNone) + if (!isa(optionalIndicesList[nextNone].getType())) + batchDim = 0; + + // Indices are extended, catted, and collapsed into a [batch, depth] tensor: + Value indices = combinePutIndices(loc, indicesList, rewriter); + + // Bove batch dimensions to the front and collapse into a single dim: + values = + collapseAndMoveBatchDims(loc, values, batchDim, numBatchDims, rewriter); + valuesType = cast(values.getType()); + + // Materialize out the length-1 dimensions: + Value zero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value one = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + llvm::SmallVector valuesShape{valuesType.getSizes().front()}; + llvm::SmallVector valuesDims; + valuesDims.push_back( + rewriter.create(loc, values, zero)); + + int vDim = 1; + for (int i = 0, s = inputType.getSizes().size(); i < s; ++i) { + if (i < optionalIndicesCount && + !isa(optionalIndicesList[i].getType())) { + valuesDims.push_back(one); + valuesShape.push_back(1); + continue; } - } - Type indicesDtype = optionalIndicesList[nonNoneIndexTensorDim[0]] - .getType() - .cast() - .getDtype(); - - // This implementation is done to get the scatter indices: - - // def get_broadcast_shape(tensors): - // return list(torch.broadcast_tensors(*tensors)[0].shape) - - // def get_input_shape_to_output_shape_map(optional_index_tensors: - // list[Optional[torch.Tensor]]): - // index_tensors = list(filter(lambda x: x is not None, - // optional_index_tensors)) broadcast_rank = - // len(get_broadcast_shape(index_tensors)) num_of_index_tensors = - // len(index_tensors) index_of_first_index_tensor: Optional[int] = None - // result = {} - // for i, index in enumerate(optional_index_tensors): - // if index is None: - // val = i - // if index_of_first_index_tensor is not None: - // val += broadcast_rank - num_of_index_tensors - // result[i] = [val] - // else: - // if index_of_first_index_tensor is None: - // index_of_first_index_tensor = i - // output_indices = list(range(index_of_first_index_tensor, - // index_of_first_index_tensor + - // broadcast_rank)) - // result[i] = output_indices - // return result - - // def spotlight_indices(shape, indices: list[int]): - // """Turn all values in `shape` to `1` except for those with index in - // `indices`.""" shape = shape.copy() for i in range(len(shape)): - // if i not in indices: - // shape[i] = 1 - // return shape - - // def get_final_shape(input, optional_index_tensors: - // list[Optional[torch.Tensor]]): - // index_tensors = list(filter(lambda x: x is not None, - // optional_index_tensors)) index_tensors_broadcast_shape = - // get_broadcast_shape(index_tensors) result = [] - // handled_index_tensor_space = False - // for e, i in enumerate(input.shape): - // if optional_index_tensors[e] is None: - // result.append(i) - // else: - // if not handled_index_tensor_space: - // handled_index_tensor_space = True - // result += index_tensors_broadcast_shape - // return result - - // def get_scatter_indices(input, optional_index_tensors: - // list[Optional[torch.Tensor]]): - // assert len(input.size()) == len(optional_index_tensors), "Pad indices - // with None" shape_map = - // get_input_shape_to_output_shape_map(optional_index_tensors) - // index_tensors = list(filter(lambda x: x is not None, - // optional_index_tensors)) index_tensors_broadcast_shape = - // get_broadcast_shape(index_tensors) final_shape = - // get_final_shape(input, optional_index_tensors) - - // broadcasted_index_tensors = [] - // for e, optional_index_tensor in enumerate(optional_index_tensors): - // if optional_index_tensor is None: - // tensor_to_broadcast = torch.arange(0, input.size(e)) - // else: - // tensor_to_broadcast = - // optional_index_tensor.broadcast_to(index_tensors_broadcast_shape) - - // broadcasted_index_tensor = \ - // tensor_to_broadcast.reshape(spotlight_indices(final_shape, shape_map[e]))\ - // .broadcast_to(final_shape)\ - // .flatten() - // broadcasted_index_tensors.append(broadcasted_index_tensor) - - // return torch.stack(broadcasted_index_tensors, dim=0).t() - - auto scatterIndicesInfo = - getScatterIndices(op, rewriter, indicesDtype, optionalIndicesList, - indexBroadcastShapeInt, indexBroadcastShapeValue); - if (failed(scatterIndicesInfo)) { - return rewriter.notifyMatchFailure( - op, "cannot generate scatter indices for index put op"); + Value k = rewriter.create( + loc, rewriter.getI64IntegerAttr(vDim)); + valuesDims.push_back( + rewriter.create(loc, values, k)); + valuesShape.push_back(inputType.getSizes()[i]); + vDim++; } - Value indexTensor = *scatterIndicesInfo; - // Flattening the values tensor. - Value torchCstZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value flattenedValuesTensorLastDim = rewriter.create( - loc, - rewriter.getI64IntegerAttr(valuesTensorType.getSizes().size() - 1)); - SmallVector valuesShapeInt{valuesTensorType.getSizes()}; - int64_t valuesCount = 1; - if (llvm::all_of(valuesShapeInt, - [](int64_t shape) { return shape != kUnknownSize; })) { - for (int64_t i : valuesShapeInt) - valuesCount *= i; - } else { - valuesCount = kUnknownSize; - } - auto flattenedValuesTensorType = ValueTensorType::get( - context, llvm::ArrayRef(valuesCount), valuesTensorType.getDtype()); - Value flattenedValuesTensor = rewriter.create( - loc, flattenedValuesTensorType, op.getValues(), torchCstZero, - flattenedValuesTensorLastDim); - values = typeConverter->materializeTargetConversion( - rewriter, loc, - typeConverter->convertType(flattenedValuesTensor.getType()), - flattenedValuesTensor); + Value valuesDimsList = rewriter.create( + loc, Torch::ListType::get(rewriter.getType()), + valuesDims); + + valuesType = rewriter.getType( + valuesShape, valuesType.getOptionalDtype()); + values = + rewriter.create(loc, valuesType, values, valuesDimsList); // `TMTensor::ScatterOp` expects indices of element type i32. - Value indices = convertTensorToDtype( - rewriter, loc, indexTensor, + indices = convertTensorToDtype( + rewriter, loc, indices, mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed)); + + input = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(input.getType()), input); + values = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(values.getType()), values); indices = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(indices.getType()), indices); @@ -931,7 +832,8 @@ class ConvertAten_IndexPutImplOp // 3.) `input` is mapped to `original` in scatter op. bool invalidInputTypeFound = false; Value scatterOp = createTMTensorScatterOp( - rewriter, loc, values, indices, input, /*uniqueIndices=*/false, + rewriter, loc, values, indices, input, indicesMap, + /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value valuesElement, Value inputElement) { Value yieldValue = valuesElement; @@ -1064,7 +966,8 @@ class ConvertAtenMaxPool2dWithIndicesBackwardOp rewriter.getAffineDimExpr(tensorOperandRank)); SmallVector indexingMaps = AffineMap::inferFromExprList( - {originalIndicesDimExprs, updatedIndicesDimExprs}); + {originalIndicesDimExprs, updatedIndicesDimExprs}, + rewriter.getContext()); SmallVector iteratorTypes( tensorOperandRank + 1, utils::IteratorType::parallel); @@ -1149,6 +1052,7 @@ class ConvertAtenMaxPool2dWithIndicesBackwardOp Value scatterOp = createTMTensorScatterOp( rewriter, loc, /*updates=*/gradOutputFlattened, /*indices=*/indicesCollapsed, /*original=*/outputTensor, + /*dimensionsMap=*/createDefaultDimMap(indicesCollapsed), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value valuesElement, Value inputElement) { @@ -1291,6 +1195,7 @@ class ConvertAtenScatterReduceTwoOp srcType.getElementType(), /*init_element=*/normalizationValue); self = createTMTensorScatterOp( rewriter, loc, normalizations, indices, self, + /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value update, Value current) { b.create(loc, update); @@ -1298,6 +1203,7 @@ class ConvertAtenScatterReduceTwoOp if (reduceEnum == torch_upstream::ReductionType::MEAN) { counts = createTMTensorScatterOp( rewriter, loc, normalizations, indices, counts, + /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value update, Value current) { b.create(loc, update); @@ -1308,7 +1214,7 @@ class ConvertAtenScatterReduceTwoOp // Create final operation Value scatterOp = createTMTensorScatterOp( rewriter, loc, updates, indices, self, - /*uniqueIndices=*/false, + /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value update, Value current) { Value result; if (reduceEnum == torch_upstream::ReductionType::SUM || @@ -1352,6 +1258,7 @@ class ConvertAtenScatterReduceTwoOp if (reduceEnum == torch_upstream::ReductionType::MEAN) { counts = createTMTensorScatterOp( rewriter, loc, updates, indices, counts, + /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value update, Value current) { Value result; @@ -1599,27 +1506,82 @@ class ConvertAtenScaledDotProductAttentionOp "only default scale supported"); } + auto opTy = cast(op.getType()).toBuiltinTensor(); + auto query = adaptor.getQuery(); + auto value = adaptor.getValue(); + auto key = adaptor.getKey(); + auto queryTy = cast(query.getType()); + auto valueTy = cast(value.getType()); + auto keyTy = cast(key.getType()); + + if (queryTy.getRank() != valueTy.getRank() || + queryTy.getRank() != keyTy.getRank()) + return rewriter.notifyMatchFailure(op, "operand ranks do not match"); + + if (queryTy.getRank() < 3) + return rewriter.notifyMatchFailure(op, "missing batch dimension"); + + llvm::SmallVector reassociation(3); + for (int i = 0, s = valueTy.getRank() - 2; i < s; ++i) + reassociation.front().push_back(i); + reassociation[1].push_back(valueTy.getRank() - 2); + reassociation[2].push_back(valueTy.getRank() - 1); + + auto loc = op.getLoc(); + auto collapseBatch = [&rewriter, &reassociation, + loc](Value value) -> Value { + auto valueTy = cast(value.getType()); + if (valueTy.getRank() == 3) + return value; + + llvm::SmallVector newShape(3, 1); + newShape[1] = valueTy.getDimSize(valueTy.getRank() - 2); + newShape[2] = valueTy.getDimSize(valueTy.getRank() - 1); + + for (int i = 0, s = valueTy.getRank() - 2; i < s; ++i) { + if (valueTy.isDynamicDim(i)) { + newShape[0] = ShapedType::kDynamic; + break; + } + newShape[0] = newShape[0] * valueTy.getDimSize(i); + } + + auto collapseTy = valueTy.clone(newShape); + return rewriter.create(loc, collapseTy, value, + reassociation); + }; + + query = collapseBatch(query); + key = collapseBatch(key); + value = collapseBatch(value); + SmallVector outSizes( - adaptor.getQuery().getType().cast().getShape()); + query.getType().cast().getShape()); SmallVector valueSizes( - adaptor.getValue().getType().cast().getShape()); + value.getType().cast().getShape()); outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1]; SmallVector outSizesDynamic( - getTensorSizes(rewriter, op.getLoc(), adaptor.getQuery())); - outSizesDynamic[outSizesDynamic.size() - 1] = getTensorSizes( - rewriter, op.getLoc(), adaptor.getValue())[valueSizes.size() - 1]; + getTensorSizes(rewriter, op.getLoc(), query)); + outSizesDynamic[outSizesDynamic.size() - 1] = + getTensorSizes(rewriter, op.getLoc(), value)[valueSizes.size() - 1]; Type outType = RankedTensorType::get(outSizes, elementType); Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic, elementType); // Overwrite with tm_tensor::attention - auto attention = rewriter.create( - op.getLoc(), outType, - SmallVector{adaptor.getQuery(), adaptor.getKey(), - adaptor.getValue()}, - SmallVector{output}); + Value attention = + rewriter + .create(loc, outType, + SmallVector{query, key, value}, + SmallVector{output}) + .getResult()[0]; + + if (opTy != outType) { + attention = rewriter.create(loc, opTy, attention, + reassociation); + } - rewriter.replaceOp(op, attention.getResult()); + rewriter.replaceOp(op, attention); return success(); } diff --git a/lib/Conversion/TorchToTensor/CMakeLists.txt b/lib/Conversion/TorchToTensor/CMakeLists.txt new file mode 100644 index 000000000000..21082d1d1258 --- /dev/null +++ b/lib/Conversion/TorchToTensor/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_conversion_library(TorchMLIRTorchToTensor + TorchToTensor.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTensor + + DEPENDS + TorchMLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRTensorDialect + TorchMLIRTorchDialect + TorchMLIRConversionUtils +) + +torch_mlir_target_includes(TorchMLIRTorchToTensor) diff --git a/lib/Conversion/TorchToTensor/TorchToTensor.cpp b/lib/Conversion/TorchToTensor/TorchToTensor.cpp new file mode 100644 index 000000000000..8b934ccb0484 --- /dev/null +++ b/lib/Conversion/TorchToTensor/TorchToTensor.cpp @@ -0,0 +1,142 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v3.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-1.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h" + +#include "../PassDetail.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +namespace { + +class ConvertAtenItemOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenItemOp::Adaptor; + LogicalResult + matchAndRewrite(AtenItemOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto operand = adaptor.getOperands()[0]; + auto operandTy = cast(operand.getType()); + auto torchDTy = cast(op.getOperand().getType()).getDtype(); + + if (operandTy.getNumElements() != 1) + return rewriter.notifyMatchFailure(op, "expected only one item"); + + auto zeroIdx = rewriter.create(op.getLoc(), 0); + auto rank = operandTy.getRank(); + llvm::SmallVector indices(rank, zeroIdx); + + Value extract = rewriter.create( + op.getLoc(), operandTy.getElementType(), operand, indices); + auto extractTy = extract.getType(); + if (isa(extractTy) && !extractTy.isInteger(64)) { + if (torchDTy.isSignlessInteger()) { + extract = rewriter.create( + op.getLoc(), rewriter.getIntegerType(64), extract); + } else { + extract = rewriter.create( + op.getLoc(), rewriter.getIntegerType(64), extract); + } + } + + if (isa(extractTy) && !extractTy.isF64()) { + extract = rewriter.create(op.getLoc(), + rewriter.getF64Type(), extract); + } + + rewriter.replaceOp(op, extract); + return success(); + } +}; + +class ConvertAtenShapeToTensorPatternOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename Aten_ShapeAsTensorOp::Adaptor; + LogicalResult + matchAndRewrite(Aten_ShapeAsTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto operand = adaptor.getOperands()[0]; + auto operandTy = operand.getType().cast(); + auto resultTy = + getTypeConverter()->convertType(op.getType()).cast(); + + int64_t rank = operandTy.getRank(); + if (rank == 0) { + rewriter.replaceOpWithNewOp(op, resultTy.getShape(), + resultTy.getElementType()); + return success(); + } + + SmallVector dims; + for (int i = 0; i < rank; ++i) { + Value dim = rewriter.createOrFold(loc, operand, i); + dim = rewriter.createOrFold( + loc, resultTy.getElementType(), dim); + dims.push_back(dim); + } + + Value tensor = + rewriter.createOrFold(op.getLoc(), dims); + rewriter.replaceOp(op, tensor); + return success(); + } +}; + +class ConvertTorchToTensor + : public ConvertTorchToTensorBase { +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + TorchConversion::getBackendTypeConversionDependentDialects(registry); + } + void runOnOperation() override { + MLIRContext *context = &getContext(); + ConversionTarget target(*context); + target.addLegalDialect(); + target.addLegalDialect(); + target.addIllegalOp(); + target.addIllegalOp(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + TorchConversion::setupBackendTypeConversion(target, typeConverter); + + RewritePatternSet patterns(context); + patterns.add( + typeConverter, context); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> +mlir::torch::createConvertTorchToTensorPass() { + return std::make_unique(); +} diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 779bd6249283..f174da4f43b7 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -8,27 +8,24 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" -#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" -#include "torch-mlir/Conversion/Utils/Utils.h" - #include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Traits.h" #include "mlir/IR/Matchers.h" +#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" +#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" +#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" -#include "llvm/ADT/SmallVector.h" #include +#include using namespace mlir; using namespace mlir::torch; @@ -1327,9 +1324,9 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { // increasing. E.g. [0, 1, 2, 3]: No transpose [1, 0, 2, 3]: Transpose dim0 // and dim1 The order need not be sequential, since one or more dims may // have been removed due to broadcasting. - auto isTransposeRequired = [](ArrayRef transposedDims) -> bool { + auto isTransposeRequired = [](SmallVector transposedDims) -> bool { int32_t lastDim = -1; - for (auto dim : transposedDims) { + for (auto &dim : transposedDims) { if (lastDim > dim) return true; lastDim = dim; @@ -1337,7 +1334,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { return false; }; - SmallVector commonElems, lhsSqueezedElems, rhsSqueezedElems; + SmallVector batchElems, lhsSqueezedElems, rhsSqueezedElems; if (!performBatchDimBroadcast) { // Simple with no broadcasting artifacts. Just reshape up to 3D @@ -1391,7 +1388,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { if (isDynamicDim || lhsBroadcastedShape[dim] == rhsBroadcastedShape[dim]) { commonValue *= lhsBroadcastedShape[dim]; - commonElems.push_back({dim, lhsBroadcastedShape[dim]}); + batchElems.push_back({dim, lhsBroadcastedShape[dim]}); } } commonValue = commonValue < 0 ? kUnknownSize : commonValue; @@ -1418,9 +1415,9 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { // Step: Create the tosa.transpose array. If this array has a // non-monotonic series of dims, perform transpose. // First the common_elems - for (uint32_t i = 0; i < commonElems.size(); i++) { - transposedLhsShape.push_back(commonElems[i].shape); - transposedLhsDims.push_back(commonElems[i].dim); + for (uint32_t i = 0; i < batchElems.size(); i++) { + transposedLhsShape.push_back(batchElems[i].shape); + transposedLhsDims.push_back(batchElems[i].dim); } // then the lhs_squeezed elems for (uint32_t i = 0; i < lhsSqueezedElems.size(); i++) { @@ -1476,9 +1473,9 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { // Step: Create the RHS transpose sequence // RHS = {common, matmul_dim, rhs_squeezed} // first the common_dims - for (uint32_t i = 0; i < commonElems.size(); i++) { - transposedRhsShape.push_back(commonElems[i].shape); - transposedRhsDims.push_back(commonElems[i].dim); + for (uint32_t i = 0; i < batchElems.size(); i++) { + transposedRhsShape.push_back(batchElems[i].shape); + transposedRhsDims.push_back(batchElems[i].dim); } // The matmul_dim of RHS transposedRhsDims.push_back(maxInputRank - 2); @@ -1556,8 +1553,15 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { SmallVector matmulOutputShape( {matmulLhsShape[0], matmulLhsShape[1], matmulRhsShape[2]}); + Type outputElemTy; + if (lhsElemTy.isa()) { + outputElemTy = lhsElemTy; + } else { // qint8 emits i32 matmul output + outputElemTy = rewriter.getIntegerType(32); + } + auto mmOutputTy = RankedTensorType::get( - makeShapeLLVMCompatible(matmulOutputShape), outputElemType); + makeShapeLLVMCompatible(matmulOutputShape), outputElemTy); auto mmOpResult = rewriter .create( @@ -1567,14 +1571,6 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { matmulLhs, matmulRhs) .getResult(); - auto originalMatMulInputType = lhsElemTy; - auto castOpResult = - rewriter - .createOrFold(op->getLoc(), - cast(mmOpResult.getType()) - .clone(originalMatMulInputType), - mmOpResult); - // Perform the reshape to output shape. This is always required unless max // input rank=3 and there was no broadcasting, in which case the tosa.matmul // output itself is correctly shaped. @@ -1586,6 +1582,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { // an unknown to-be-inferred output shape. The final tensor.cast // reshapes the known shape to the desired output shape. auto computeOpShape = [&](SmallVector &reshapedOpShape, + SmallVector &transposedOpDims, SmallVector &transposedOpShapes) { if (maxInputRank == 1) return; @@ -1600,8 +1597,9 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { // Step: Construct the output transpose/reshape information // First the common_dims - for (uint32_t i = 0; i < commonElems.size(); i++) { - reshapedOpShape.push_back(commonElems[i].shape); + for (uint32_t i = 0; i < batchElems.size(); i++) { + reshapedOpShape.push_back(batchElems[i].shape); + transposedOpDims.push_back(batchElems[i].dim); } // Then the LHS squeezed dims @@ -1610,12 +1608,14 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { // other input. if (lhsSqueezedElems[i].shape != 1) { reshapedOpShape.push_back(lhsSqueezedElems[i].shape); + transposedOpDims.push_back(lhsSqueezedElems[i].dim); } } // The last squeezed dim is lhs[-2] which needs to be // checked separately for broadcasting if (lhsRank > 1) { reshapedOpShape.push_back(lhsBroadcastedShape[maxInputRank - 2]); + transposedOpDims.push_back(maxInputRank - 2); } // then the RHS squeezed dims except rhs[-1] which is handled like @@ -1623,13 +1623,23 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { for (uint32_t i = 0; i < rhsSqueezedElems.size() - 1; i++) { if (rhsSqueezedElems[i].shape != 1) { reshapedOpShape.push_back(rhsSqueezedElems[i].shape); + transposedOpDims.push_back(rhsSqueezedElems[i].dim); } } // rhs[-1] if (rhsRank > 1) { reshapedOpShape.push_back(rhsBroadcastedShape[maxInputRank - 1]); + transposedOpDims.push_back(maxInputRank - 1); } + // The transposition order is the inverse of what we actually want, + // inversing should fix this: + llvm::SmallVector inverseTransposeDims(transposedOpDims.size()); + for (int i = 0, s = transposedOpDims.size(); i < s; ++i) + inverseTransposeDims[transposedOpDims[i]] = i; + + transposedOpDims = inverseTransposeDims; + // Final transposed output shape construction for (uint32_t i = 0; i < maxInputRank - 2; i++) { if (lhsBroadcastedTy.isDynamicDim(i)) { @@ -1652,43 +1662,32 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { return; }; - // Calculated output shapes for reshape and transpose - SmallVector reshapedOpShape; - SmallVector transposedOpShape; - computeOpShape(reshapedOpShape, transposedOpShape); + SmallVector reshapedOpShape, transposedOpShape; + SmallVector transposedOpDims; + + computeOpShape(reshapedOpShape, transposedOpDims, transposedOpShape); + + bool opNeedsTranspose = isTransposeRequired(transposedOpDims); // Perform reshape auto reshapedOpType = RankedTensorType::get( - makeShapeLLVMCompatible(reshapedOpShape), originalMatMulInputType); + makeShapeLLVMCompatible(reshapedOpShape), outputElemTy); auto reshapedOp = rewriter.create( op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( reshapedOpType), - castOpResult, rewriter.getDenseI64ArrayAttr(reshapedOpShape)); - - // Calculate transmutation required - SetVector transmutationSetVec; - for (unsigned i = 0; i < transposedOpShape.size(); i++) { - for (unsigned j = 0; j < reshapedOpShape.size(); j++) { - if (!transmutationSetVec.contains(j) && - transposedOpShape[i] == reshapedOpShape[j]) { - transmutationSetVec.insert(j); - break; - } - } - } - ArrayRef transVec = transmutationSetVec.getArrayRef(); + mmOpResult, rewriter.getDenseI64ArrayAttr(reshapedOpShape)); + + if (opNeedsTranspose) { - if (isTransposeRequired(transVec)) { std::optional transposedOpShapeConst = tosa::getConstTensor( rewriter, op, - /*vec=*/transVec, - /*shape=*/{static_cast(transVec.size())}); + /*vec=*/transposedOpDims, + /*shape=*/{static_cast(transposedOpDims.size())}); - auto transposedOpType = - RankedTensorType::get(makeShapeLLVMCompatible(transposedOpShape), - originalMatMulInputType); + auto transposedOpType = RankedTensorType::get( + makeShapeLLVMCompatible(transposedOpShape), outputElemTy); output = rewriter .create( op->getLoc(), @@ -1701,7 +1700,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { output = reshapedOp.getResult(); } } else { - output = castOpResult; + output = mmOpResult; } return success(); @@ -2010,6 +2009,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenConvolutionOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + bool transposed; + if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) + return rewriter.notifyMatchFailure( + op, "Unimplemented: non-constant value for transposed not supported"); + if (transposed) + return rewriter.notifyMatchFailure( + op, "Unimplemented: transposed convolution not supported"); + auto input = adaptor.getInput(); auto weight = adaptor.getWeight(); @@ -2078,15 +2085,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "non-const padding list unsupported"); - bool transposed; - if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) - return rewriter.notifyMatchFailure( - op, "transpose must be a bool constant"); - - if (transposed) - return rewriter.notifyMatchFailure( - op, "Unimplemented: only non-transposed convolutions supported"); - // TOSA uses 4D padding {t, b, l, r} while Torch defines 2D padding {t, l}. // The Torch OFM computation uses 2*pad in each spatial direction, implying // the same t=b and l=r values for TOSA. @@ -3493,6 +3491,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenSliceTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + if(op->use_empty()) { + rewriter.eraseOp(op); + return success(); + } + auto selfType = adaptor.getSelf().getType().dyn_cast(); if (!selfType || !selfType.hasStaticShape()) return rewriter.notifyMatchFailure( @@ -3524,12 +3527,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) return rewriter.notifyMatchFailure(op, "start must be a Scalar constant"); - // support for start < 0 - start = toPositiveDim(start, sizeOfDim); + start = toPositiveDim(start, selfType.getShape()[dim]); start = std::clamp(start, (int64_t)0, sizeOfDim); - start = std::min(selfType.getShape()[dim], start); - int64_t end; if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) { if (isa(op.getEnd().getDefiningOp())) @@ -4509,59 +4509,51 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "only tensor types input are currently supported"); - int64_t intMin = 0; - int64_t intMax = 0; - double fpMin = 0.0; - double fpMax = 0.0; - - auto min = op.getMin(); - auto isIntMin = matchPattern(min, m_TorchConstantInt(&intMin)); - auto isFloatMin = matchPattern(min, m_TorchConstantFloat(&fpMin)); - auto isNoneTypeMin = min.getType().isa(); - - auto max = op.getMax(); - auto isIntMax = matchPattern(max, m_TorchConstantInt(&intMax)); - auto isFloatMax = matchPattern(max, m_TorchConstantFloat(&fpMax)); - auto isNoneTypeMax = max.getType().isa(); - - if (!(isIntMin || isFloatMin || isNoneTypeMin)) - return rewriter.notifyMatchFailure( - op, "unimplemented: value `int_min` should be a torch constant " - "int/float or Torch::NoneType"); + IntegerAttr min_int = + rewriter.getI64IntegerAttr(std::numeric_limits::min()); + IntegerAttr max_int = + rewriter.getI64IntegerAttr(std::numeric_limits::max()); + FloatAttr min_fp = + rewriter.getF32FloatAttr(std::numeric_limits::lowest()); + FloatAttr max_fp = + rewriter.getF32FloatAttr(std::numeric_limits::max()); + + auto getValAttr = [&](Value operand, IntegerAttr &intAttr, + FloatAttr &fpAttr) -> LogicalResult { + double valFloat; + int64_t valInt; + if (matchPattern(operand, m_TorchConstantFloat(&valFloat))) { + intAttr = rewriter.getI64IntegerAttr(static_cast(valFloat)); + fpAttr = rewriter.getF32FloatAttr(static_cast(valFloat)); + } else if (matchPattern(operand, m_TorchConstantInt(&valInt))) { + intAttr = rewriter.getI64IntegerAttr(valInt); + fpAttr = rewriter.getF32FloatAttr(static_cast(valInt)); + } else { + return failure(); + } + return success(); + }; - if (!(isIntMax || isFloatMax || isNoneTypeMax)) + LogicalResult minAttrResult = getValAttr(op.getMin(), min_int, min_fp); + LogicalResult maxAttrResult = getValAttr(op.getMax(), max_int, max_fp); + if (failed(minAttrResult) && failed(maxAttrResult)) { return rewriter.notifyMatchFailure( - op, "unimplemented: value `int_max` should be a torch constant " - "int/float or Torch::NoneType"); - - // Adjust min and max to their numeric_limits if type == Torch::NoneType. - if (isNoneTypeMin) { - intMin = std::numeric_limits::min(); - fpMin = std::numeric_limits::lowest(); + op, "either `min` or `max` should be a torch constant"); } - if (isNoneTypeMax) { - intMax = std::numeric_limits::max(); - fpMax = std::numeric_limits::max(); + if (failed(minAttrResult) && + succeeded(checkNotNone(rewriter, op, op.getMin()))) { + return rewriter.notifyMatchFailure(op, + "min attr should be a torch constant"); + } + if (failed(maxAttrResult) && + succeeded(checkNotNone(rewriter, op, op.getMax()))) { + return rewriter.notifyMatchFailure(op, + "max attr should be a torch constant"); } - - // If we are using integer for min and max values, - // import them from their fp counterparts. - if (isIntMin) - fpMin = static_cast(intMin); - - if (isIntMax) - fpMax = static_cast(intMax); auto outType = getTypeConverter()->convertType(op.getType()); - - // It is safe to static_cast to float since tosa doesn't support fp64. - FloatAttr minFp = rewriter.getF32FloatAttr(static_cast(fpMin)); - FloatAttr maxFp = rewriter.getF32FloatAttr(static_cast(fpMax)); - IntegerAttr minInt = rewriter.getI64IntegerAttr(intMin); - IntegerAttr maxInt = rewriter.getI64IntegerAttr(intMax); - rewriter.replaceOpWithNewOp(op, outType, adaptor.getSelf(), - minInt, maxInt, minFp, maxFp); + min_int, max_int, min_fp, max_fp); return success(); } @@ -4589,59 +4581,138 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "unimplemented: pin_memory must be either None or false"); } - auto matchIntOrDouble = - [&](Value val) -> std::tuple { - // Match int or fp values. The one used depends on the resultType. - // Therefore `valueInt` and `valueDouble` will have similar values (but may - // be truncated due to casting). - int64_t valueInt = 0; - double valueDouble = 0.0; - if (matchPattern(val, m_TorchConstantInt(&valueInt))) - return {success(), valueInt, static_cast(valueInt)}; - if (matchPattern(val, m_TorchConstantFloat(&valueDouble))) - return {success(), static_cast(valueDouble), valueDouble}; - return {failure(), valueInt, valueDouble}; + // Stores a range value (a start, end, or step value) and whether or not it + // was initiated with a constant integer, an constant float or neither. + class ConstRangeValue { + public: + explicit ConstRangeValue(double v) + : vDouble(v), fromDouble(true), vInt(static_cast(v)), + fromInt(false) {} + + explicit ConstRangeValue(int64_t v) + : vDouble(static_cast(v)), fromDouble(false), vInt(v), + fromInt(true) {} + + // Constructor for the case where there is no constant value to use. + ConstRangeValue() + : vDouble(0), fromDouble(false), vInt(0), fromInt(false) {} + + static ConstRangeValue fromValue(Value v) { + int64_t intVal{0}; + double floatVal{0.0}; + if (matchPattern(v, m_TorchConstantFloat(&floatVal))) { + return ConstRangeValue(floatVal); + } else if (matchPattern(v, m_TorchConstantInt(&intVal))) { + return ConstRangeValue(intVal); + } + return ConstRangeValue(); + } + + bool hasConstInt() const { return fromInt; } + bool hasConstDouble() const { return fromDouble; } + bool hasConst() const { return fromInt || fromDouble; } + double getDouble() const { return vDouble; } + int64_t getInt() const { return vInt; } + + private: + double vDouble; + bool fromDouble; + int64_t vInt; + bool fromInt; }; - auto [matchStart, startInt, startDouble] = matchIntOrDouble(op.getStart()); - if (failed(matchStart)) + auto start = ConstRangeValue::fromValue(op.getStart()); + if (!start.hasConst()) { return rewriter.notifyMatchFailure( - op, - "unimplemented: value `start` should be a torch constant int or float"); + op, "unimplemented: case where `start` is not a constant int or float"); + } - auto [matchEnd, endInt, endDouble] = matchIntOrDouble(op.getEnd()); - if (failed(matchEnd)) + auto end = ConstRangeValue::fromValue(op.getEnd()); + if (!end.hasConst()) { return rewriter.notifyMatchFailure( op, - "unimplemented: value `end` should be a torch constant int or float"); + "unimplemented: case where value `end` is not a constant int or float"); + } - auto [matchStep, stepInt, stepDouble] = matchIntOrDouble(op.getStep()); - if (failed(matchStep)) - return rewriter.notifyMatchFailure( - op, - "unimplemented: value `step` should be a torch constant int or float"); + auto step = ConstRangeValue::fromValue(op.getStep()); + if (!step.hasConst()) { + return rewriter.notifyMatchFailure(op, + "unimplemented: case where value `step` " + "is not a constant int or float"); + } - // The result will always be a 1-d tensor. - // The size of the result is calculated as follows: - // ceil((end - start)/step) - auto elementType = resultType.getElementType(); - Value result; - if (isa(elementType)) { - int64_t resultShape = ceil(static_cast(endInt - startInt) / - static_cast(stepInt)); - SmallVector values(resultShape, startInt); - for (unsigned i = 1; i < resultShape; i++) - values[i] += i * stepInt; - result = tosa::getConstTensor(rewriter, op, values, resultShape) - .value(); - } else { - int64_t resultShape = ceil((endDouble - startDouble) / stepDouble); - SmallVector values(resultShape, startDouble); - for (unsigned i = 1; i < resultShape; i++) - values[i] += static_cast(i) * stepDouble; - result = tosa::getConstTensor(rewriter, op, values, resultShape) - .value(); + auto getRange = [](auto start, auto end, auto step) { + // Initialize a small vector of the same type as start: + using T = decltype(start); + SmallVector values; + + uint64_t counter{0}; + if (start == end) { + return values; + } + assert(step != T(0)); + values.reserve( + 1 + static_cast(std::abs((end - start) / std::abs(step)))); + if (step > 0) { + while (start + T(counter) * step < end) { + values.push_back(start + counter * step); + counter++; + } + } else { + while (start + T(counter) * step > end) { + values.push_back(start + counter * step); + counter++; + } + } + return values; + }; + + const auto isIntType = + resultType.getElementType().dyn_cast_or_null(); + + const auto isDoubleType = + resultType.getElementType().dyn_cast_or_null(); + + auto maybeResult = [&]() -> std::optional { + // Integer output type, and start / end / range are all integers. + if (isIntType && start.hasConstInt() && end.hasConstInt() && + step.hasConstInt()) { + auto values = getRange(start.getInt(), end.getInt(), step.getInt()); + return tosa::getConstTensor(rewriter, op, values, values.size()); + } + + // Get a double range. + auto values = + getRange(start.getDouble(), end.getDouble(), step.getDouble()); + if (isIntType) { + SmallVector values_i64; + values_i64.reserve(values.size()); + for (auto v : values) { + values_i64.push_back(static_cast(v)); + } + return tosa::getConstTensor(rewriter, op, values_i64, + values.size()); + } + + if (!isDoubleType) { + return {}; + } + + SmallVector values_f32; + values_f32.reserve(values.size()); + for (auto v : values) { + values_f32.push_back(static_cast(v)); + } + auto vs = tosa::getConstTensor(rewriter, op, values_f32, + values_f32.size()); + return vs; + }(); + + if (!maybeResult.has_value()) { + return rewriter.notifyMatchFailure( + op, "failed to generate constant tensor for arange"); } + auto result = maybeResult.value(); rewriter.replaceOpWithNewOp(op, resultType, result); return success(); @@ -5298,9 +5369,9 @@ class ConvertAtenFillScalarOp : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "Supplied value must be a Scalar constant"); - auto newOp = - rewriter.createOrFold(op.getLoc(), outType, constOp); + auto newOp = rewriter.createOrFold(op.getLoc(), outType, constOp); rewriter.replaceOp(op, newOp); + return success(); } }; @@ -5981,6 +6052,8 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { mlir::tosa::convertReduceMeanOp) INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp, mlir::tosa::convertReduceSumOp) + INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp, + mlir::tosa::convertLinalgVectorNormOp) #undef INSERT_NDIMS_REDUCTION_OP_PATTERN #define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index e16bd0cb507e..b4e82360c60f 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -9,6 +9,7 @@ #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" #include "torch-mlir/Conversion/Utils/Utils.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include #include @@ -130,10 +131,10 @@ tosa::DivOp createBinaryOpAndCast(PatternRewriter &rewriter, } std::optional convertTorchIndexToTfIndices(PatternRewriter &rewriter, - Operation *op, - Value paramsValue, - Value indexValue, - int32_t axis) { + Operation *op, + Value paramsValue, + Value indexValue, + int32_t axis) { // For easy understanding of this algorithm, the following comments are with // an exact example: torch.aten.gather(!torch.vtensor<[1,4,3],f32>, axis=2, // !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],f32> @@ -209,9 +210,9 @@ std::optional convertTorchIndexToTfIndices(PatternRewriter &rewriter, // Lowers Gather operators to a sequence of TOSA ops. // taken from // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc -std::optional convertGatherNdOp(PatternRewriter &rewriter, - Operation *op, Type outType, - Value paramsValue, Value indicesValue) { +std::optional convertGatherNdOp(PatternRewriter &rewriter, Operation *op, + Type outType, Value paramsValue, + Value indicesValue) { auto resultType = outType.dyn_cast(); auto paramsType = paramsValue.getType().dyn_cast(); auto indicesType = indicesValue.getType().dyn_cast(); @@ -682,7 +683,6 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, .getResult(); } - // Common function for lowering reduce operations to TOSA ops. template std::optional convertReduceOpCommon( @@ -727,9 +727,8 @@ std::optional convertReduceOpCommon( auto axis_attr = rewriter.getI32IntegerAttr(axis_val); shape_vec[axis_val] = 1; - RankedTensorType reduce_type = RankedTensorType::get( - shape_vec, - reduce_element_type); + RankedTensorType reduce_type = + RankedTensorType::get(shape_vec, reduce_element_type); auto reduce_op = CreateOpAndInfer(rewriter, op->getLoc(), reduce_type, val, axis_attr); @@ -978,5 +977,75 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op, return val; } +// Lowers LinalgVectorNorm to a sequence of TOSA ops. +std::optional +convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op, + RankedTensorType output_type, Value input_value, + ElementsAttr axes_elems, bool keep_dims) { + RankedTensorType input_type = + input_value.getType().dyn_cast(); + if (!input_type) + return std::nullopt; + + Type elemType = output_type.getElementType(); + if (!elemType.isa()) { + op->emitOpError("Only floating-point datatype legalization supported for " + "AtenLinalgVectorNorm op"); + return std::nullopt; + } + + auto linalgVectorNormOp = cast(op); + // TODO: Add support for ord = {0, +inf, -inf}. + auto epsilon = 1e-5; + double ordLiteralFloat = 1.0; + int64_t ordLiteralInt = 1; + Value ordVal; + if (matchPattern(linalgVectorNormOp.getOrd(), + torch::Torch::m_TorchConstantFloat(&ordLiteralFloat))) { + ordVal = tosa::getConstTensor(rewriter, op, + {static_cast(ordLiteralFloat)}, + {}, elemType) + .value(); + } else if (matchPattern(linalgVectorNormOp.getOrd(), + torch::Torch::m_TorchConstantInt(&ordLiteralInt))) { + ordVal = tosa::getConstTensor(rewriter, op, + {static_cast(ordLiteralInt)}, + {}, elemType) + .value(); + } else { + op->emitOpError("only support FP or INT type ord parameter"); + return std::nullopt; + } + + if (fabs(ordLiteralFloat) < epsilon || + fabs(static_cast(ordLiteralInt)) < epsilon) { + op->emitOpError("unimplemented: L0 norm"); + return std::nullopt; + } + + if (std::isinf(ordLiteralFloat) || + std::isinf(static_cast(ordLiteralInt))) { + op->emitOpError("unimplemented: ord = +/- inf"); + return std::nullopt; + } + + auto absVal = CreateOpAndInfer(rewriter, op->getLoc(), + input_type, input_value) + .getResult(); + auto powVal = CreateOpAndInfer(rewriter, op->getLoc(), + input_type, absVal, ordVal) + .getResult(); + std::optional result = convertReduceSumOp( + rewriter, op, output_type, powVal, axes_elems, keep_dims); + if (!result) + return std::nullopt; + auto reciprocalVal = CreateOpAndInfer( + rewriter, op->getLoc(), ordVal.getType(), ordVal) + .getResult(); + return CreateOpAndInfer(rewriter, op->getLoc(), output_type, + result.value(), reciprocalVal) + .getResult(); +} + } // namespace tosa } // namespace mlir diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index ed7f6b2a9539..d2fe75390e68 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -33,8 +33,9 @@ Value buildRescale(PatternRewriter &rewriter, Operation *op, rewriter.getI32IntegerAttr(static_cast(input_zp)), rewriter.getI32IntegerAttr(static_cast(output_zp)), rewriter.getDenseI32ArrayAttr({multiplier}), - rewriter.getDenseI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32), - rewriter.getBoolAttr(double_round), rewriter.getBoolAttr(false)); + rewriter.getDenseI8ArrayAttr({static_cast(shift)}), + rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round), + rewriter.getBoolAttr(false)); return rescale_op.getResult(); } @@ -86,8 +87,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op, rewriter, op->getLoc(), output_type, conv_val, rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp), rewriter.getDenseI32ArrayAttr({multiplier}), - rewriter.getDenseI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32), - rewriter.getBoolAttr(true), rewriter.getBoolAttr(false)); + rewriter.getDenseI8ArrayAttr({static_cast(shift)}), + rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(true), + rewriter.getBoolAttr(false)); return rescale_op.getResult(); @@ -96,7 +98,7 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op, .dyn_cast()) { // Per-channel quantization SmallVector multiplier_arr; - SmallVector shift_arr; + SmallVector shift_arr; SmallVector weight_scale_arr( weight_per_channel_qtype.getScales().begin(), @@ -115,14 +117,14 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op, scale_width); multiplier_arr.push_back(multiplier); - shift_arr.push_back(shift); + shift_arr.push_back(static_cast(shift)); } auto rescale_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type, conv_val, rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp), rewriter.getDenseI32ArrayAttr(multiplier_arr), - rewriter.getDenseI32ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32), + rewriter.getDenseI8ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(true), rewriter.getBoolAttr(true)); return rescale_op.getResult(); @@ -186,7 +188,8 @@ std::optional getZerosLikeTensor(PatternRewriter &rewriter, // Default template creates a constant tensor in T. template std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, - ArrayRef vec, ArrayRef shape, std::optional dtype) { + ArrayRef vec, ArrayRef shape, + std::optional dtype) { uint64_t num_total_elements = 1; for (int64_t a : shape) { num_total_elements *= a; @@ -198,7 +201,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, } auto width = sizeof(T) * 8; - if constexpr(std::is_same_v) + if constexpr (std::is_same_v) width = 1; auto const_type = @@ -209,7 +212,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, rewriter.create(op->getLoc(), const_type, const_attr); if (dtype) { - return rewriter.createOrFold( + return rewriter.createOrFold( op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); } return const_op.getResult(); @@ -219,7 +222,8 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, template <> std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, ArrayRef vec, - ArrayRef shape, std::optional dtype) { + ArrayRef shape, + std::optional dtype) { uint64_t num_total_elements = 1; for (int64_t a : shape) { num_total_elements *= a; @@ -237,7 +241,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, auto const_op = rewriter.create(op->getLoc(), const_type, const_attr); if (dtype) { - return rewriter.createOrFold( + return rewriter.createOrFold( op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); } return const_op.getResult(); @@ -247,7 +251,8 @@ std::optional getConstTensor(PatternRewriter &rewriter, template <> std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, ArrayRef vec, - ArrayRef shape, std::optional dtype) { + ArrayRef shape, + std::optional dtype) { uint64_t num_total_elements = 1; for (int64_t a : shape) { num_total_elements *= a; @@ -292,7 +297,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, rewriter.create(op->getLoc(), const_type, const_attr); if (dtype) { - return rewriter.createOrFold( + return rewriter.createOrFold( op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); } return const_op.getResult(); @@ -419,23 +424,17 @@ TypedValue transposeBy(Location loc, PatternRewriter &rewriter } // Template instantiation -template std::optional getConstTensor(PatternRewriter &, - Operation *, - ArrayRef vec, - ArrayRef shape, - std::optional dtype); - -template std::optional getConstTensor(PatternRewriter &, - Operation *, - ArrayRef vec, - ArrayRef shape, - std::optional dtype); - -template std::optional getConstTensor(PatternRewriter &, - Operation *, - ArrayRef vec, - ArrayRef shape, - std::optional dtype); +template std::optional +getConstTensor(PatternRewriter &, Operation *, ArrayRef vec, + ArrayRef shape, std::optional dtype); + +template std::optional +getConstTensor(PatternRewriter &, Operation *, ArrayRef vec, + ArrayRef shape, std::optional dtype); + +template std::optional +getConstTensor(PatternRewriter &, Operation *, ArrayRef vec, + ArrayRef shape, std::optional dtype); LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input, TypeAttr &accType) { diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index 3df9da94b735..064215c51da0 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -245,12 +245,20 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef shape, elementType, encoding); } +static std::optional getIntegerValue(Value scalar) { + if (auto constOp = scalar.getDefiningOp()) { + return std::optional(constOp.getValue()); + } + return std::optional(); +} + // Convert a scalar value to the target type. The scalar value can be an element // from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype // should be converted builtin types. Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, std::optional srcOriginalDtype, - std::optional dstOriginalDtype) { + std::optional dstOriginalDtype, + std::optional originalScalar) { Type scalarType = scalar.getType(); if (scalarType == dtype) return scalar; @@ -262,7 +270,8 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, return false; }; - // We don't support conversion to Byte dtype. + // We support conversion to Byte dtype only if the original scalar is an + // integer constant with value lying between 0 - 63. if (isByteOrChar(dtype)) { if (!dstOriginalDtype.has_value()) { mlir::emitError(loc) @@ -271,10 +280,22 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, return nullptr; } if (dstOriginalDtype->isUnsignedInteger()) { - mlir::emitError(loc) - << "unsupported: conversion to byte type for convertScalarToDtype " - << scalarType << "(scalar type) -> " << dtype << "(dtype)"; - return nullptr; + if (originalScalar.has_value()) { + std::optional optConstVal = + getIntegerValue(originalScalar.value()); + if (optConstVal.has_value()) { + int64_t constVal = optConstVal.value(); + if (constVal < 0 || constVal > 63) { + // Do the conversion only if the original integer value is between + // 0 - 63. + mlir::emitError(loc) + << "unsupported: conversion to byte type for " + "convertScalarToDtype " + << scalarType << "(scalar type) -> " << dtype << "(dtype)"; + return nullptr; + } + } + } } } diff --git a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp index dcb2f4215891..7b8a17682a9e 100644 --- a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp +++ b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp @@ -94,31 +94,22 @@ LogicalResult AttentionOp::verify() { ShapedType keyType = getKeyType(); ArrayRef queryShape = queryType.getShape(); ArrayRef keyShape = keyType.getShape(); - if (keyShape[0] != queryShape[0]) - return op->emitOpError("query and key batch mismatch"); - if (keyShape[2] != queryShape[2]) + for (int i = 0, s = queryShape.size() - 2; i < s; ++i) { + if (keyShape[i] != queryShape[i]) + return op->emitOpError("query and key batch mismatch"); + } + if (keyShape.back() != queryShape.back()) return op->emitOpError("query and key head dimension mismatch"); return success(); } SmallVector AttentionOp::getIterationDomain(OpBuilder &builder) { - int64_t iterationDomainRank = getIterationDomainRank(); - SmallVector loopBounds(iterationDomainRank); - Location loc = getLoc(); - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); - Value source = getQuery(); - for (auto dim : llvm::seq(0, iterationDomainRank)) { - loopBounds[dim].offset = zero; - loopBounds[dim].size = getDimValue(builder, loc, source, dim); - loopBounds[dim].stride = one; - } + SmallVector loopBounds; return loopBounds; } SmallVector AttentionOp::getLoopIteratorTypes() { - SmallVector iteratorTypes(getIterationDomainRank(), - utils::IteratorType::parallel); + SmallVector iteratorTypes; return iteratorTypes; } @@ -166,7 +157,6 @@ static void matmul(OpBuilder &b, Location loc, Value lhs, ValueRange lhsSizes, }) ->getResult(0); b.create(loc, sum, output, localIVs); - b.create(loc); }); } @@ -190,6 +180,8 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, Value zeroF = b.create(loc, elementType, b.getFloatAttr(elementType, 0.0)); + // TODO: This needs to be fixed, it assumes everything is dynamic however if + // any shapes are static the `memref.alloc` generated is illegal. SmallVector queryDynSizes, keyDynSizes, valueDynSizes, outputDynSizes; for (auto i = 0; i < queryRank; i++) queryDynSizes.push_back(b.create(loc, query, i)); @@ -205,9 +197,18 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, auto weightSizes = SmallVector(queryType.getShape()); weightSizes[weightRank - 1] = keySizes[keyRank - 2]; auto weightType = MemRefType::get(weightSizes, queryType.getElementType()); + + // Setup the weight dynamic sizes: SmallVector weightDynSizes(queryDynSizes); weightDynSizes[weightRank - 1] = keyDynSizes[keyRank - 2]; - Value weight = b.create(loc, weightType, weightDynSizes); + + SmallVector weightFilteredDynSizes; + for (int i = 0; i < weightRank; ++i) + if (weightSizes[i] == ShapedType::kDynamic) + weightFilteredDynSizes.push_back(weightDynSizes[i]); + + Value weight = + b.create(loc, weightType, weightFilteredDynSizes); matmul(b, loc, query, queryDynSizes, key, keyDynSizes, weight, weightDynSizes, /*transposed=*/true); @@ -229,13 +230,15 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, SmallVector(weightRank, one), init, [&](OpBuilder &b, Location loc, ValueRange localIVs, ValueRange accs) { - b.create( - loc, init, - [&](OpBuilder &b, Location loc, Value elem, Value acc) { - Value x = b.create(loc, weight, localIVs); - Value max = b.create(loc, x, acc); - b.create(loc, max); - }); + auto reduceOp = b.create(loc, init); + // Build reduce body. + Block &reductionBody = reduceOp.getReductions()[0].front(); + auto bodyBuilder = OpBuilder::atBlockEnd(&reductionBody); + Value acc = reductionBody.getArgument(0); + Value x = + bodyBuilder.create(loc, weight, localIVs); + Value max = bodyBuilder.create(loc, x, acc); + bodyBuilder.create(loc, max); }) .getResult(0); // weight = (weight - max(weight)) / math.sqrt(querySizes[-1]) @@ -247,7 +250,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, x = b.create(loc, x, globalMax); x = b.create(loc, x, scaleFactor); b.create(loc, x, weight, localIVs); - b.create(loc); }); // calculate exp(weight) SmallVector min(weightRank, zero), @@ -258,14 +260,18 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, Value x = b.create(loc, weight, localIVs); x = b.create(loc, x); b.create(loc, x, weight, localIVs); - b.create(loc); }); + + llvm::SmallVector expWeightDynDims(weightFilteredDynSizes); + if (weightSizes.back() == ShapedType::kDynamic) + expWeightDynDims.resize(expWeightDynDims.size() - 1); + Value expWeightSum = b.create( loc, MemRefType::get( SmallVector(weightSizes.begin(), weightSizes.end() - 1), elementType), - SmallVector{weightDynSizes.begin(), weightDynSizes.end() - 1}); + expWeightDynDims); b.create( loc, SmallVector(weightRank - 1, zero), SmallVector{weightDynSizes.begin(), weightDynSizes.end() - 1}, @@ -290,7 +296,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, Value y = b.create(loc, weight, coords); Value sum = b.create(loc, x, y); b.create(loc, sum, expWeightSum, outsideDims); - b.create(loc); }); }); // calculate exp(weight) / sum(exp(weight)) @@ -305,7 +310,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, Value sum = b.create(loc, expWeightSum, sumIVs); x = b.create(loc, x, sum); b.create(loc, x, weight, localIVs); - b.create(loc); }); // output = weight @ value @@ -505,12 +509,32 @@ LogicalResult ScanOp::fold(FoldAdaptor adaptor, //===----------------------------------------------------------------------===// // ScatterOp //===----------------------------------------------------------------------===// +static Type getComplexElementTypeOrSelf(Type ty) { + if (auto complex = dyn_cast_or_null(ty)) + return complex.getElementType(); + return ty; +} + +static bool isInvalid(ArrayRef dimsPos, int64_t rank) { + // early exit. + if (static_cast(dimsPos.size()) > rank) + return true; + DenseSet uniqued; + for (int64_t dim : dimsPos) + uniqued.insert(dim); + if (static_cast(dimsPos.size()) != uniqued.size()) + return true; + return llvm::any_of( + dimsPos, [rank](int64_t dimPos) { return dimPos < 0 || dimPos >= rank; }); +} + LogicalResult ScatterOp::verify() { + Operation *op = getOperation(); if (getInputs().size() != 2) { - return emitOpError("expected two input operands"); + return op->emitOpError("expected two input operands"); } if (getOutputs().size() != 1) { - return emitOpError("expected one output operand"); + return op->emitOpError("expected one output operand"); } auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) { return t1.getShape()[dim] == t2.getShape()[dim]; @@ -522,10 +546,19 @@ LogicalResult ScatterOp::verify() { return emitOpError("expected indices to be of rank 2 of i32 element type"); } auto indexDepth = getIndexDepth(); - if (indexDepth == ShapedType::kDynamic) { + if (ShapedType::isDynamic(indexDepth)) { return emitOpError("expected index depth is static"); } + ArrayRef dimMap = getDimensionMap(); + if (static_cast(dimMap.size()) != indexDepth) { + return op->emitOpError("invalid number of dimension map entries "); + } + + auto originalType = getOriginalType(); + if (isInvalid(dimMap, originalType.getRank())) + return op->emitOpError("dimension map is invalid"); + // The first dimension of the indices should match the first dimension of the // output. They indicate to the number of updates. auto updateType = getUpdateType(); @@ -536,7 +569,6 @@ LogicalResult ScatterOp::verify() { return emitOpError( "mismatch in shape of indices and update value at dim#0"); } - auto originalType = getOriginalType(); if (updateType.getRank() - 1 > originalType.getRank()) { return emitOpError( "update value rank exceeds the rank of the original value"); @@ -549,7 +581,7 @@ LogicalResult ScatterOp::verify() { "index depth and update value does not cover rank of original value"); } - // Validate the non-indexed update dims covier the full slice size of the + // Validate the non-indexed update dims cover the full slice size of the // original tensor. int64_t fullSliceDims = originalType.getRank() - indexDepth; for (auto it : @@ -558,10 +590,11 @@ LogicalResult ScatterOp::verify() { updateType.getRank()))) { int64_t originalDim = std::get<0>(it); int64_t updateDim = std::get<1>(it); - if (updateType.getDimSize(updateDim) != - originalType.getDimSize(originalDim)) { - return emitOpError("mismatch in shape of update value dim#") - << updateDim << " and original value at dim#" << originalDim; + if (!originalType.isDynamicDim(originalDim) && + updateType.getDimSize(updateDim) > + originalType.getDimSize(originalDim)) { + return op->emitOpError("shape of update value dim#") + << updateDim << " exceeds original value at dim#" << originalDim; } } @@ -572,23 +605,25 @@ LogicalResult ScatterOp::verify() { llvm::seq(1, updateType.getRank() - fullSliceDims))) { int64_t originalDim = std::get<0>(it); int64_t updateDim = std::get<1>(it); - if (updateType.getDimSize(updateDim) > - originalType.getDimSize(originalDim)) { - return emitOpError("indexed shape of update value dim#") + if (!originalType.isDynamicDim(originalDim) && + updateType.getDimSize(updateDim) > + originalType.getDimSize(originalDim)) { + return op->emitOpError("indexed shape of update value dim#") << updateDim << " exceeds original value at dim#" << originalDim << " " << updateType.getDimSize(updateDim) << " " << originalType.getDimSize(originalDim); } } - Region &thisRegion = getRegion(); - Block *body = &thisRegion.front(); + Region ®ion = this->getRegion(); + Block *body = ®ion.front(); if (body->getNumArguments() != 2) { - return emitOpError("expected region to have two arguments"); + return op->emitOpError("expected region to have two arguments"); } Type arg0Type = body->getArgument(0).getType(); Type arg1Type = body->getArgument(1).getType(); - if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) { + if (!getComplexElementTypeOrSelf(arg0Type).isIntOrFloat() || + !getComplexElementTypeOrSelf(arg1Type).isIntOrFloat()) { return emitOpError( "expected region to have scalar argument of integer or float types"); } @@ -680,14 +715,16 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b, starts[it.index() + offset] = it.value(); } + ArrayRef dimMap = getDimensionMap(); for (auto i : llvm::seq(0, indexDepth)) { loadIndices.back() = b.create(loc, i); Value idx = b.create(loc, indices(), loadIndices); - Value cast = b.create(loc, b.getIndexType(), idx); + Value ret = b.create(loc, b.getIndexType(), idx); - if (starts[i]) - cast = b.create(loc, cast, starts[i]); - starts[i] = cast; + auto dim = dimMap[i]; + if (starts[dim]) + ret = b.create(loc, ret, starts[dim]); + starts[dim] = ret; } Value init = b.create(loc, original(), starts); diff --git a/lib/Dialect/TMTensor/Transforms/Bufferize.cpp b/lib/Dialect/TMTensor/Transforms/Bufferize.cpp index 64352ad1d5ce..1e8c91e8afd4 100644 --- a/lib/Dialect/TMTensor/Transforms/Bufferize.cpp +++ b/lib/Dialect/TMTensor/Transforms/Bufferize.cpp @@ -87,7 +87,8 @@ static TMTensorOp createTMTensorOpOnBuffers(ConversionPatternRewriter &rewriter, ValueRange outputs) { SmallVector newOperands = inputs; newOperands.append(outputs.begin(), outputs.end()); - return cast(tmtensorOp.clone(rewriter, tmtensorOp->getLoc(), {}, newOperands)); + return cast( + tmtensorOp.clone(rewriter, tmtensorOp->getLoc(), {}, newOperands)); } /// Generic conversion pattern that matches any TMTensorOp. This avoids template diff --git a/lib/Dialect/Torch/IR/TorchDialect.cpp b/lib/Dialect/Torch/IR/TorchDialect.cpp index 5c90df8e6ac4..e7fcbb434a2c 100644 --- a/lib/Dialect/Torch/IR/TorchDialect.cpp +++ b/lib/Dialect/Torch/IR/TorchDialect.cpp @@ -157,7 +157,7 @@ Operation *TorchDialect::materializeConstant(OpBuilder &builder, return builder.create(loc, intValue); } } - + if (type.isa()) { return builder.create(loc, value.cast()); diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index e6b29ca98060..fff872b32198 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -6,9 +6,10 @@ // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// - +#define DEBUG_TYPE "torch-mlir-torch-dialect" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "llvm/Support/Debug.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" @@ -203,8 +204,8 @@ static Value getScalarFloatValue(Value input, Location loc, //===----------------------------------------------------------------------===// LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto func = - symbolTable.lookupNearestSymbolFrom(*this, getFunctionAttr()); + auto func = symbolTable.lookupNearestSymbolFrom( + *this, getFunctionAttr()); if (!func) return emitError() << "'@" << getFunction() << "' does not reference a valid function"; @@ -453,11 +454,13 @@ void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, // If the condition is constant, delete the dead branch and inline the live // branch. patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) { - auto constantBool = op.getCondition().getDefiningOp(); + auto constantBool = + op.getCondition().getDefiningOp(); if (!constantBool) return rewriter.notifyMatchFailure(op, "non-constant condition"); - replaceOpWithRegion( - rewriter, op, constantBool.getValue() ? op.getThenRegion() : op.getElseRegion()); + replaceOpWithRegion(rewriter, op, + constantBool.getValue() ? op.getThenRegion() + : op.getElseRegion()); return success(); }); // If the thenRegion and elseRegion yield the same Value's, then use those @@ -515,14 +518,16 @@ void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, continue; newResultTypes.push_back(op->getResult(i).getType()); } - auto newIf = - rewriter.create(op->getLoc(), newResultTypes, op.getCondition()); + auto newIf = rewriter.create(op->getLoc(), newResultTypes, + op.getCondition()); rewriter.inlineRegionBefore(op.getThenRegion(), newIf.getThenRegion(), newIf.getThenRegion().end()); rewriter.inlineRegionBefore(op.getElseRegion(), newIf.getElseRegion(), newIf.getElseRegion().end()); - newIf.getThenRegion().front().getTerminator()->eraseOperands(resultsToErase); - newIf.getElseRegion().front().getTerminator()->eraseOperands(resultsToErase); + newIf.getThenRegion().front().getTerminator()->eraseOperands( + resultsToErase); + newIf.getElseRegion().front().getTerminator()->eraseOperands( + resultsToErase); SmallVector replacementValues; for (int i = 0, e = op->getNumResults(), nextNewValue = 0; i < e; ++i) { if (resultsToErase[i]) @@ -548,8 +553,8 @@ void RuntimeAssertOp::getCanonicalizationPatterns(RewritePatternSet &patterns, return failure(); if (value) { - rewriter.eraseOp(op); - return success(); + rewriter.eraseOp(op); + return success(); } // Even if the condition is statically false, the assert might never be // executed. @@ -715,6 +720,8 @@ OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) { + if (getOperand().getType() != getResult().getType()) + return nullptr; if (auto tensorType = getOperand().getType().dyn_cast()) { if (tensorType.hasSizes() && tensorType.getSizes().size() == 0) return getOperand(); @@ -727,6 +734,8 @@ OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) { + if (getOperand(0).getType() != getResult().getType()) + return nullptr; if (auto tensorType = getOperand(0).getType().dyn_cast()) { if (tensorType.hasSizes() && tensorType.getSizes().size() == 0) return getOperand(0); @@ -734,18 +743,6 @@ OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) { return nullptr; } -//===----------------------------------------------------------------------===// -// AtenRoundOp -//===----------------------------------------------------------------------===// - -OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) { - if (auto selfType = getSelf().getType().dyn_cast()) { - if (selfType.hasDtype() && selfType.getDtype().isa()) - return getSelf(); - } - return nullptr; -} - //===----------------------------------------------------------------------===// // AtenToDtypeOp //===----------------------------------------------------------------------===// @@ -892,10 +889,10 @@ void AtenToOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, auto rhs = op.getOther(); auto getRhsDevice = rewriter.create(op.getLoc(), rhs); auto getRhsDtype = rewriter.create(op.getLoc(), rhs); - rewriter.replaceOpWithNewOp( - op, op.getType(), lhs, getRhsDevice.getResult(), - getRhsDtype.getResult(), op.getNonBlocking(), - op.getCopy(), op.getMemoryFormat()); + rewriter.replaceOpWithNewOp( + op, op.getType(), lhs, getRhsDevice.getResult(), + getRhsDtype.getResult(), op.getNonBlocking(), op.getCopy(), + op.getMemoryFormat()); return success(); }); } @@ -911,6 +908,8 @@ OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) { auto resType = getType().dyn_cast(); if (!resType || !resType.hasSizes() || resType.getSizes().size() != 1) return nullptr; + if (inputType != resType) + return nullptr; // Fold when both the input tensor and result are unity rank tensors. return getOperand(0); } @@ -988,7 +987,7 @@ void AtenMaxOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, // `aten.max.other` -> `aten.maximum` patterns.add(+[](AtenMaxOtherOp op, PatternRewriter &rewriter) { rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), - op.getOther()); + op.getOther()); return success(); }); } @@ -1093,6 +1092,177 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op, return success(); } +//===----------------------------------------------------------------------===// +// NAry folder helpers +//===----------------------------------------------------------------------===// + +static bool checkSameDTypes(llvm::ArrayRef attrs) { + bool allFp = true; + bool allInt = true; + + for (auto attr : attrs) { + if (!attr) + return false; + + Type attrty; + if (auto dense = dyn_cast_or_null(attr)) + attrty = dense.getType(); + if (auto fp = dyn_cast_or_null(attr)) + attrty = fp.getType(); + if (auto integer = dyn_cast_or_null(attr)) + attrty = integer.getType(); + if (auto shaped = dyn_cast_or_null(attrty)) + attrty = shaped.getElementType(); + allFp &= isa(attrty); + allInt &= isa(attrty); + } + + return allFp || allInt; +} + +static bool checkAllSplats(llvm::ArrayRef attrs) { + for (auto attr : attrs) { + if (auto dense = dyn_cast_or_null(attr)) { + if (!dense.isSplat()) + return false; + } + } + + return true; +} + +llvm::SmallVector getFoldValueAtIndexFp(llvm::ArrayRef attrs, + int64_t idx = 0) { + llvm::SmallVector splattrs; + + for (auto attr : attrs) { + if (auto dense = dyn_cast(attr)) { + if (dense.isSplat()) { + splattrs.push_back(dense.getSplatValue().convertToDouble()); + } else { + splattrs.push_back(dense.getValues()[idx].convertToDouble()); + } + } else if (auto intattr = dyn_cast(attr)) { + splattrs.push_back(intattr.getValueAsDouble()); + } else { + return {}; + } + } + + return splattrs; +} + +llvm::SmallVector getFoldValueAtIndexInt(llvm::ArrayRef attrs, + int64_t bitwidth, + int64_t idx = 0) { + llvm::SmallVector splattrs; + + for (auto attr : attrs) { + bool isunsigned = false; + if (auto dense = dyn_cast(attr)) { + isunsigned = dyn_cast(dense.getElementType()).isUnsigned(); + if (dense.isSplat()) { + splattrs.push_back(dense.getSplatValue()); + } else { + splattrs.push_back(dense.getValues()[idx]); + } + } else if (auto intattr = dyn_cast(attr)) { + isunsigned = cast(intattr.getType()).isUnsigned(); + splattrs.push_back(intattr.getValue()); + } else { + return {}; + } + + auto &apint = splattrs.back(); + if (apint.getBitWidth() < bitwidth) { + if (isunsigned) { + apint = apint.zextOrTrunc(bitwidth); + } else { + apint = apint.sextOrTrunc(bitwidth); + } + } + } + + return splattrs; +} + +using NAryFoldFpOperator = std::function)>; +using NAryFoldIntOperator = std::function)>; + +static OpFoldResult naryFolderHelper(ArrayRef operands, Type ty, + NAryFoldFpOperator fpFolder, + NAryFoldIntOperator intFolder) { + constexpr int64_t maxFold = 16; + if (!checkSameDTypes(operands)) + return nullptr; + + auto resultTy = dyn_cast(ty); + if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes()) + return nullptr; + + auto dty = resultTy.getDtype(); + auto resultBTy = resultTy.toBuiltinTensor().clone(dty); + + auto fpTy = dyn_cast(dty); + auto intTy = dyn_cast(dty); + if (!fpTy && !intTy) + return nullptr; + + bool allSplats = checkAllSplats(operands); + bool withinMaxFold = + resultBTy.hasStaticShape() && resultBTy.getNumElements() <= maxFold; + + if (!allSplats && !withinMaxFold) + return nullptr; + + // We do not support broadcasting in the non-splat case so validate same + // shaped inputs / outputs: + if (!allSplats) { + auto resultShape = resultBTy.getShape(); + for (int i = 0, s = operands.size(); i < s; ++i) { + if (auto dense = dyn_cast(operands[i])) { + if (dense.isSplat()) + continue; + auto operandShape = cast(dense.getType()).getShape(); + if (operandShape.size() != resultShape.size()) + return nullptr; + for (int i = 0, s = operandShape.size(); i < s; ++i) + if (operandShape[i] != resultShape[i]) + return nullptr; + } + } + } + + const int64_t numValues = allSplats ? 1 : resultBTy.getNumElements(); + + if (fpTy) { + llvm::SmallVector folded; + for (int i = 0, s = numValues; i < s; ++i) { + auto inputs = getFoldValueAtIndexFp(operands, i); + double fold = fpFolder(inputs); + + APFloat val(fold); + bool unused; + val.convert(fpTy.getFloatSemantics(), APFloat::rmNearestTiesToEven, + &unused); + folded.push_back(val); + } + return DenseElementsAttr::get(resultBTy, folded); + } + + if (intTy) { + llvm::SmallVector folded; + for (int i = 0, s = numValues; i < s; ++i) { + auto inputs = + getFoldValueAtIndexInt(operands, dty.getIntOrFloatBitWidth(), i); + folded.push_back(intFolder(inputs)); + } + return DenseElementsAttr::get(resultBTy, folded); + } + + return nullptr; +} + //===----------------------------------------------------------------------===// // AtenAddTensorOp //===----------------------------------------------------------------------===// @@ -1103,6 +1273,20 @@ void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +OpFoldResult AtenAddTensorOp::fold(FoldAdaptor adaptor) { + auto fpFold = [](llvm::ArrayRef inputs) { + assert(inputs.size() == 3); + return inputs[0] + (inputs[1] * inputs[2]); + }; + + auto intFold = [](llvm::ArrayRef inputs) { + assert(inputs.size() == 3); + return inputs[0] + (inputs[1] * inputs[2]); + }; + + return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold); +} + //===----------------------------------------------------------------------===// // AtenAddScalarOp //===----------------------------------------------------------------------===// @@ -1123,6 +1307,20 @@ void AtenSubTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +OpFoldResult AtenSubTensorOp::fold(FoldAdaptor adaptor) { + auto fpFold = [](llvm::ArrayRef inputs) { + assert(inputs.size() == 3); + return inputs[0] - (inputs[1] * inputs[2]); + }; + + auto intFold = [](llvm::ArrayRef inputs) { + assert(inputs.size() == 3); + return inputs[0] - (inputs[1] * inputs[2]); + }; + + return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold); +} + //===----------------------------------------------------------------------===// // AtenSubScalarOp //===----------------------------------------------------------------------===// @@ -1153,20 +1351,350 @@ void AtenMulTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +OpFoldResult AtenMulTensorOp::fold(FoldAdaptor adaptor) { + auto fpFold = [](llvm::ArrayRef inputs) { + assert(inputs.size() == 2); + return inputs[0] * inputs[1]; + }; + + auto intFold = [](llvm::ArrayRef inputs) { + assert(inputs.size() == 2); + return inputs[0] * inputs[1]; + }; + + return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold); +} + //===----------------------------------------------------------------------===// -// AtenFloorOp +// AtenEqTensorOp //===----------------------------------------------------------------------===// -void AtenFloorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, - MLIRContext *context) { - patterns.add(+[](AtenFloorOp op, PatternRewriter &rewriter) { - auto outputTy = op.getType().dyn_cast(); - if (outputTy && outputTy.hasDtype() && - outputTy.getDtype().isa()) { - rewriter.replaceOp(op, op.getSelf()); - return success(); + +OpFoldResult AtenEqTensorOp::fold(FoldAdaptor adaptor) { + constexpr int64_t kMaxFold = 16; + auto ty = dyn_cast(getType()); + if (!ty || !ty.hasDtype() || !ty.hasSizes()) + return nullptr; + + auto bty = ty.toBuiltinTensor().clone(ty.getDtype()); + if (!bty.hasStaticShape()) + return nullptr; + + if (getSelf() == getOther()) + return DenseElementsAttr::get(bty, + IntegerAttr::get(bty.getElementType(), 1)); + + auto self = dyn_cast_or_null(adaptor.getSelf()); + auto other = dyn_cast_or_null(adaptor.getOther()); + if (!self || !other) + return nullptr; + + auto selfTy = dyn_cast(self.getType()); + auto otherTy = dyn_cast(other.getType()); + if (!selfTy || !otherTy || + selfTy.getElementType() != otherTy.getElementType()) + return nullptr; + + // If both values are splats we can just compute the output value as a splat. + if (self.isSplat() && other.isSplat()) { + if (isa(selfTy.getElementType())) { + APFloat lhsFp = self.getSplatValue(); + APFloat rhsFp = other.getSplatValue(); + bool eq = lhsFp.compare(rhsFp) == APFloat::cmpEqual; + return DenseElementsAttr::get(bty, eq); } - return failure(); - }); + + if (isa(selfTy.getElementType())) { + APInt lhsInt = self.getSplatValue(); + APInt rhsInt = other.getSplatValue(); + bool eq = lhsInt == rhsInt; + return DenseElementsAttr::get(bty, eq); + } + + return nullptr; + } + + if (selfTy != otherTy || bty.getNumElements() > kMaxFold) + return nullptr; + + if (isa(selfTy.getElementType())) { + auto extract = [bty](DenseElementsAttr attr) { + llvm::SmallVector vals; + if (attr.isSplat()) { + vals.resize(bty.getNumElements(), attr.getSplatValue()); + return vals; + } + + for (auto fp : attr.getValues()) { + vals.push_back(fp); + } + return vals; + }; + + llvm::SmallVector lhsFp = extract(self); + llvm::SmallVector rhsFp = extract(other); + llvm::SmallVector vals(bty.getNumElements()); + for (int i = 0, s = bty.getNumElements(); i < s; ++i) { + vals[i] = lhsFp[i].compare(rhsFp[i]) == APFloat::cmpEqual; + } + + return DenseElementsAttr::get(bty, vals); + } + + if (isa(selfTy.getElementType())) { + auto extract = [bty](DenseElementsAttr attr) { + llvm::SmallVector vals; + if (attr.isSplat()) { + vals.resize(bty.getNumElements(), attr.getSplatValue()); + return vals; + } + + for (auto fp : attr.getValues()) { + vals.push_back(fp); + } + return vals; + }; + + llvm::SmallVector lhsInt = extract(self); + llvm::SmallVector rhsInt = extract(other); + llvm::SmallVector vals(bty.getNumElements()); + for (int i = 0, s = bty.getNumElements(); i < s; ++i) { + vals[i] = lhsInt[i] == rhsInt[i]; + } + + return DenseElementsAttr::get(bty, vals); + } + + return nullptr; +} + +//===----------------------------------------------------------------------===// +// AtenLeScalarOp +//===----------------------------------------------------------------------===// + +using ComparisonFoldFpOperator = std::function; +using ComparisonFoldIntOperator = std::function; + +static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, + ValueTensorType resultTy, + ComparisonFoldFpOperator fpFolder, + ComparisonFoldIntOperator intFolder) { + constexpr int64_t kMaxFold = 16; + if (!lhs || !rhs || !resultTy) + return nullptr; + if (!resultTy.hasSizes() || !resultTy.hasDtype()) + return nullptr; + + for (auto size : resultTy.getSizes()) + if (size == Torch::kUnknownSize) + return nullptr; + + auto ctx = lhs.getContext(); + auto resultETy = resultTy.getDtype(); + auto tensorETy = cast(lhs.getType()).getElementType(); + if (lhs.isSplat()) { + if (auto intAttr = dyn_cast(rhs)) { + auto unsign = cast(tensorETy).isUnsigned(); + auto scalarAP = intAttr.getValue(); + auto tensorAP = lhs.getSplatValue().getValue(); + tensorAP = APInt( + scalarAP.getBitWidth(), + unsign ? tensorAP.getZExtValue() : tensorAP.getSExtValue(), !unsign); + auto resultBool = intFolder(tensorAP, scalarAP, unsign); + auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool); + return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), + resultAP); + } + + if (auto floatAttr = dyn_cast(rhs)) { + APFloat scalarAP = floatAttr.getValue(); + APFloat tensorAP = lhs.getSplatValue().getValue(); + auto resultBool = + fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble()); + auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool); + return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), + resultAP); + } + return nullptr; + } + + int64_t count = 1; + for (auto size : resultTy.getSizes()) + count *= size; + + if (count > kMaxFold) + return nullptr; + + if (auto intAttr = dyn_cast(rhs)) { + auto unsign = cast(tensorETy).isUnsigned(); + llvm::SmallVector values; + for (auto tensorAP : lhs.getValues()) { + auto scalarAP = intAttr.getValue(); + tensorAP = APInt( + scalarAP.getBitWidth(), + unsign ? tensorAP.getZExtValue() : tensorAP.getSExtValue(), !unsign); + auto resultBool = intFolder(tensorAP, scalarAP, unsign); + values.push_back(resultBool); + } + return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), + values); + } + + if (auto floatAttr = dyn_cast(rhs)) { + llvm::SmallVector values; + for (auto tensorAP : lhs.getValues()) { + APFloat scalarAP = floatAttr.getValue(); + auto resultBool = + fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble()); + values.push_back(resultBool); + } + return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), + values); + } + + return nullptr; +} + +OpFoldResult AtenLeScalarOp::fold(FoldAdaptor adaptor) { + auto self = dyn_cast_or_null(adaptor.getSelf()); + auto other = adaptor.getOther(); + auto resultTy = dyn_cast(getType()); + + auto fpFold = [](double lhs, double rhs) -> bool { return lhs <= rhs; }; + + auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool { + return unsign ? lhs.ule(rhs) : lhs.sle(rhs); + }; + + return comparisonScaleFolder(self, other, resultTy, fpFold, intFold); +} + +//===----------------------------------------------------------------------===// +// AtenLtScalarOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenLtScalarOp::fold(FoldAdaptor adaptor) { + auto self = dyn_cast_or_null(adaptor.getSelf()); + auto other = adaptor.getOther(); + auto resultTy = dyn_cast(getType()); + + auto fpFold = [](double lhs, double rhs) -> bool { return lhs < rhs; }; + + auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool { + return unsign ? lhs.ult(rhs) : lhs.slt(rhs); + }; + + return comparisonScaleFolder(self, other, resultTy, fpFold, intFold); +} + +//===----------------------------------------------------------------------===// +// AtenGtScalarOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenGtScalarOp::fold(FoldAdaptor adaptor) { + auto self = dyn_cast_or_null(adaptor.getSelf()); + auto other = adaptor.getOther(); + auto resultTy = dyn_cast(getType()); + + auto fpFold = [](double lhs, double rhs) -> bool { return lhs > rhs; }; + + auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool { + return unsign ? lhs.ugt(rhs) : lhs.sgt(rhs); + }; + + return comparisonScaleFolder(self, other, resultTy, fpFold, intFold); +} + +//===----------------------------------------------------------------------===// +// AtenGeScalarOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenGeScalarOp::fold(FoldAdaptor adaptor) { + auto self = dyn_cast_or_null(adaptor.getSelf()); + auto other = adaptor.getOther(); + auto resultTy = dyn_cast(getType()); + + auto fpFold = [](double lhs, double rhs) -> bool { return lhs >= rhs; }; + + auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool { + return unsign ? lhs.uge(rhs) : lhs.sge(rhs); + }; + + return comparisonScaleFolder(self, other, resultTy, fpFold, intFold); +} + +//===----------------------------------------------------------------------===// +// AtenEqScalarOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenEqScalarOp::fold(FoldAdaptor adaptor) { + auto self = dyn_cast_or_null(adaptor.getSelf()); + auto other = adaptor.getOther(); + auto resultTy = dyn_cast(getType()); + + auto fpFold = [](double lhs, double rhs) -> bool { return lhs == rhs; }; + + auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool { + return lhs.eq(rhs); + }; + + return comparisonScaleFolder(self, other, resultTy, fpFold, intFold); +} + +//===----------------------------------------------------------------------===// +// AtenNeScalarOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenNeScalarOp::fold(FoldAdaptor adaptor) { + auto self = dyn_cast_or_null(adaptor.getSelf()); + auto other = adaptor.getOther(); + auto resultTy = dyn_cast(getType()); + + auto fpFold = [](double lhs, double rhs) -> bool { return lhs != rhs; }; + + auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool { + return lhs.ne(rhs); + }; + + return comparisonScaleFolder(self, other, resultTy, fpFold, intFold); +} + +//===----------------------------------------------------------------------===// +// AtenFloorOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) { + auto resultType = getType().dyn_cast(); + if (resultType && resultType.hasDtype() && + resultType.getDtype().isa()) { + return getSelf(); + } + return {}; +} + +//===----------------------------------------------------------------------===// +// AtenCeilOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) { + auto resultType = getType().dyn_cast(); + if (resultType && resultType.hasDtype() && + resultType.getDtype().isa()) { + return getSelf(); + } + return {}; +} + +//===----------------------------------------------------------------------===// +// AtenRoundOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) { + auto resultType = getType().dyn_cast(); + if (resultType && resultType.hasDtype() && + resultType.getDtype().isa()) { + return getSelf(); + } + return {}; } //===----------------------------------------------------------------------===// @@ -1241,16 +1769,48 @@ void AtenScalarImplicitOp::getCanonicalizationPatterns( } //===----------------------------------------------------------------------===// -// AtenSizeOp +// AtenFloatImplicitOp //===----------------------------------------------------------------------===// +void AtenFloatImplicitOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(+[](AtenFloatImplicitOp op, PatternRewriter &rewriter) { + Location loc = op.getLoc(); + Value a = op.getA(); + Value scalarValue = getScalarFloatValue(a, loc, rewriter); + if (!scalarValue) + return failure(); + rewriter.replaceOp(op, scalarValue); + return success(); + }); +} -// Traces at most 6 parents of `value` to determine the tensor type with known -// dimension size or returns failure if such a type was not found. If `dim` is -// `None`, then all dimension's sizes must be known. -static FailureOr -traceKnownSizeTensorType(Value value, std::optional dim) { - // Function to check if we found a type that contains the queried information. - auto foundType = [](BaseTensorType tensorType, std::optional(dim)) { +//===----------------------------------------------------------------------===// +// AtenIntImplicitOp +//===----------------------------------------------------------------------===// +void AtenIntImplicitOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenIntImplicitOp op, PatternRewriter &rewriter) { + Location loc = op.getLoc(); + Value a = op.getA(); + Value scalarValue = getScalarIntValue(a, loc, rewriter); + if (!scalarValue) + return failure(); + rewriter.replaceOp(op, scalarValue); + return success(); + }); +} + +//===----------------------------------------------------------------------===// +// AtenSizeOp +//===----------------------------------------------------------------------===// + +// Traces at most 6 parents of `value` to determine the tensor type with known +// dimension size or returns failure if such a type was not found. If `dim` is +// `None`, then all dimension's sizes must be known. +static FailureOr +traceKnownSizeTensorType(Value value, std::optional dim) { + // Function to check if we found a type that contains the queried information. + auto foundType = [](BaseTensorType tensorType, std::optional(dim)) { if (!tensorType.hasSizes()) return false; @@ -1314,6 +1874,41 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +//===----------------------------------------------------------------------===// +// AtenSelectIntOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenSelectIntOp::fold(FoldAdaptor adaptor) { + auto self = dyn_cast_or_null(adaptor.getSelf()); + auto ty = dyn_cast(getType()); + if (!self || !ty || !ty.hasDtype() || !ty.hasSizes()) + return nullptr; + + auto selfTy = cast(self.getType()); + auto bty = ty.toBuiltinTensor().clone(ty.getDtype()); + if (!bty.hasStaticShape()) + return nullptr; + + if (self.isSplat()) + return DenseElementsAttr::get(bty, self.getSplatValue()); + + auto dimAttr = dyn_cast_or_null(adaptor.getDim()); + auto indexAttr = dyn_cast_or_null(adaptor.getIndex()); + if (!dimAttr || !indexAttr || bty.getNumElements() != 1) + return nullptr; + + auto dim = dimAttr.getInt(); + auto index = indexAttr.getInt(); + + for (int i = 0, s = selfTy.getRank(); i < s; ++i) { + if (i != dim && selfTy.getDimSize(i) != 1) + return nullptr; + } + + auto splattr = self.getValues()[index]; + return DenseElementsAttr::get(bty, splattr); +} + //===----------------------------------------------------------------------===// // AtenSizeIntOp //===----------------------------------------------------------------------===// @@ -1453,7 +2048,11 @@ static OpFoldResult intComparatorFoldHelper(OpTy op, // AtenDetachOp //===----------------------------------------------------------------------===// -OpFoldResult AtenDetachOp::fold(FoldAdaptor adaptor) { return getSelf(); } +OpFoldResult AtenDetachOp::fold(FoldAdaptor adaptor) { + if (getSelf().getType() != getResult().getType()) + return {}; + return getSelf(); +} //===----------------------------------------------------------------------===// // AtenEmptyMemoryFormatOp @@ -1661,6 +2260,19 @@ void AtenMaskedFillTensorOp::getCanonicalizationPatterns( }); } +//===----------------------------------------------------------------------===// +// AtenCloneOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenCloneOp::fold(FoldAdaptor adaptor) { + // note: memory_format would be ignored + if (llvm::dyn_cast(getSelf().getType())) { + // self should have value semantics + return getSelf(); + } + return {}; +} + //===----------------------------------------------------------------------===// // AtenSortIntOp //===----------------------------------------------------------------------===// @@ -1695,6 +2307,52 @@ void AtenSortIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +//===----------------------------------------------------------------------===// +// AtenSortOp +//===----------------------------------------------------------------------===// + +LogicalResult AtenSortOp::fold(FoldAdaptor adaptor, + SmallVectorImpl &results) { + auto operand = getSelf(); + auto operandType = dyn_cast(operand.getType()); + if (!operandType || !operandType.hasSizes()) + return failure(); + + // only ValueTensorType has toBuiltinTensor + auto indicesTensorType = dyn_cast(getResult(1).getType()); + if (!indicesTensorType) + return failure(); + + if (!indicesTensorType.hasDtype()) + return failure(); + auto indicesType = + indicesTensorType.toBuiltinTensor().clone(indicesTensorType.getDtype()); + if (!indicesType || !indicesType.hasStaticShape()) + return failure(); + + bool unaryDim = false; + IntegerAttr dimAttribute = dyn_cast_if_present(adaptor.getDim()); + if (!dimAttribute) + return failure(); + int64_t dimInt = dimAttribute.getValue().getSExtValue(); + if (dimInt < 0) + dimInt += operandType.getSizes().size(); + if (dimAttribute) { + unaryDim = operandType.getSizes()[dimInt] == 1; + } + + OpBuilder builder(getContext()); + if (unaryDim || llvm::all_of(operandType.getSizes(), + [](int64_t dim) { return dim == 1; })) { + results.push_back(operand); + results.push_back(DenseElementsAttr::get( + indicesType, builder.getZeroAttr(indicesType.getElementType()))); + return success(); + } + + return failure(); +} + //===----------------------------------------------------------------------===// // NonValueTensorLiteralOp //===----------------------------------------------------------------------===// @@ -1941,7 +2599,7 @@ void Torch::ConstantFloatOp::getAsmResultNames( // float string representation). SmallVector buf; getValue().toString(buf, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0, - /*TruncateZero=*/false); + /*TruncateZero=*/false); auto isValidMLIRIdentifierChar = [](char c) { return isalpha(c) || isdigit(c) || c == '_' || c == '$' || c == '.' || c == '-'; @@ -2052,7 +2710,8 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns( // compiler treat the size as having value semantics? // There's a small number of such ops, and they are marked as `inplace_view` // in PyTorch's `native_functions.yaml` file. - rewriter.replaceOpWithNewOp(op, sizeOp.getSelf(), op.getIdx()); + rewriter.replaceOpWithNewOp(op, sizeOp.getSelf(), + op.getIdx()); return success(); }); } @@ -2080,11 +2739,13 @@ OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) { void AtenAddTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenAddTOp op, PatternRewriter &rewriter) { - auto lhsListConstruct = op.getA().getDefiningOp(); + auto lhsListConstruct = + op.getA().getDefiningOp(); if (!lhsListConstruct || isListPotentiallyMutated(lhsListConstruct)) return failure(); - auto rhsListConstruct = op.getB().getDefiningOp(); + auto rhsListConstruct = + op.getB().getDefiningOp(); if (!rhsListConstruct || isListPotentiallyMutated(rhsListConstruct)) return failure(); @@ -2202,7 +2863,8 @@ LogicalResult PrimTupleConstructOp::verify() { void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](PrimTupleIndexOp op, PatternRewriter &rewriter) { - auto tupleConstruct = op.getTup().getDefiningOp(); + auto tupleConstruct = + op.getTup().getDefiningOp(); if (!tupleConstruct) return failure(); @@ -2252,7 +2914,8 @@ void PrimUninitializedOp::getCanonicalizationPatterns( void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](PrimTupleUnpackOp op, PatternRewriter &rewriter) { - auto tupleConstruct = op.getTup().getDefiningOp(); + auto tupleConstruct = + op.getTup().getDefiningOp(); if (!tupleConstruct) return failure(); @@ -2407,9 +3070,7 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef operands, // AtenAliasOp //===----------------------------------------------------------------------===// -OpFoldResult AtenAliasOp::fold(FoldAdaptor adaptor) { - return getOperand(); -} +OpFoldResult AtenAliasOp::fold(FoldAdaptor adaptor) { return getOperand(); } //===----------------------------------------------------------------------===// // AtenFloordivIntOp @@ -2453,10 +3114,92 @@ OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) { + // We set a maximum folding size of 16. This is a reasonable upper limit + // for shape computations. + constexpr int64_t kMaxFoldSize = 16; auto list = getOperand(0).getDefiningOp(); - if (!list || !list->hasOneUse() || list.getElements().size() != 1) + if (!list) return nullptr; - return list.getElements()[0]; + + auto elements = list.getElements(); + if (elements.size() == 1 && elements[0].getType() == getResult().getType()) + return list.getElements()[0]; + + auto resultTy = dyn_cast(getType()); + if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) + return nullptr; + + auto bResultTy = resultTy.toBuiltinTensor(); + if (!bResultTy.hasStaticShape() || bResultTy.getNumElements() > kMaxFoldSize) + return nullptr; + + auto dimAttr = dyn_cast_or_null(adaptor.getDim()); + if (!dimAttr) + return nullptr; + auto dim = dimAttr.getValue().getSExtValue(); + dim += dim < 0 ? bResultTy.getRank() : 0; + + for (int i = 0, s = bResultTy.getRank(); i < s; ++i) { + if (i == dim) + continue; + if (bResultTy.getDimSize(i) != 1) + return nullptr; + } + + llvm::SmallVector values; + for (auto operand : list.getOperands()) { + DenseElementsAttr dattr; + if (!matchPattern(operand, m_Constant(&dattr))) + return nullptr; + + auto oty = dyn_cast(dattr.getType()); + if (!oty) + return nullptr; + + if (dattr.isSplat()) { + for (int i = 0, s = oty.getDimSize(dim); i < s; ++i) + values.push_back(dattr.getSplatValue()); + } else { + auto evals = dattr.getValues(); + for (int i = 0, s = oty.getDimSize(dim); i < s; ++i) + values.push_back(evals[i]); + } + } + + return DenseElementsAttr::get(bResultTy.clone(resultTy.getDtype()), values); +} + +void AtenCatOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenCatOp op, PatternRewriter &rewriter) { + auto list = op.getTensors().getDefiningOp(); + auto resultTy = dyn_cast(op.getType()); + if (!list || !resultTy) + return failure(); + + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return failure(); + + llvm::SmallVector filtered; + for (auto operand : list.getOperands()) { + auto operandTy = dyn_cast(operand.getType()); + if (!operandTy || !operandTy.hasSizes()) + return failure(); + int64_t adim = dim < 0 ? dim + operandTy.getSizes().size() : dim; + if (operandTy.getSizes()[adim] != 0) + filtered.push_back(operand); + } + + if (filtered.size() == list.getNumOperands()) + return failure(); + + auto newlist = rewriter.create( + op.getLoc(), list.getType(), filtered); + rewriter.replaceOpWithNewOp(op, op.getType(), newlist, + op.getDim()); + return success(); + }); } //===----------------------------------------------------------------------===// @@ -2466,17 +3209,32 @@ OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) { auto inType = getOperand(0).getType().dyn_cast(); auto outType = getResult().getType().dyn_cast(); - if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes()) + if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() || + !outType.hasDtype()) return nullptr; - if (inType.getSizes().size() != outType.getSizes().size() || - (!isAssumingStrictSymbolicShapes((*this)->getBlock()) && - (!inType.areAllSizesKnown() || !outType.areAllSizesKnown()))) + + if (!inType.areAllSizesKnown() || !outType.areAllSizesKnown()) return nullptr; - for (size_t i = 0; i < inType.getSizes().size(); ++i) { - if (inType.getSizes()[i] != outType.getSizes()[i]) - return nullptr; + + auto inSizes = inType.getSizes(); + auto outSizes = outType.getSizes(); + if (inSizes.size() == outSizes.size()) { + bool sameSizes = true; + for (int i = 0, s = inSizes.size(); i < s; ++i) + sameSizes &= inSizes[i] == outSizes[i]; + + if (sameSizes) + return getOperand(0); } - return getOperand(0); + + auto selfAttr = dyn_cast_or_null(adaptor.getSelf()); + if (!selfAttr) + return nullptr; + if (!selfAttr.isSplat()) + return nullptr; + + auto attrty = RankedTensorType::get(outType.getSizes(), outType.getDtype()); + return DenseElementsAttr::get(attrty, selfAttr.getSplatValue()); } //===----------------------------------------------------------------------===// @@ -2484,22 +3242,75 @@ OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { - int64_t start, end, step; - if (matchPattern(getStart(), m_TorchConstantInt(&start)) && - matchPattern(getEnd(), m_TorchConstantInt(&end)) && - matchPattern(getStep(), m_TorchConstantInt(&step)) - && step == 1 - && start == 0 - && end == std::numeric_limits::max()) - return getOperand(0); - - auto inType = getOperand(0).getType().dyn_cast(); - auto outType = getResult().getType().dyn_cast(); - if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes()) + DenseElementsAttr input = + dyn_cast_or_null(adaptor.getSelf()); + IntegerAttr start = dyn_cast_or_null(adaptor.getStart()); + IntegerAttr end = dyn_cast_or_null(adaptor.getEnd()); + IntegerAttr step = dyn_cast_or_null(adaptor.getStep()); + IntegerAttr dim = dyn_cast_or_null(adaptor.getDim()); + + if (start && end && step && step.getValue().getSExtValue() == 1 && + start.getValue().getSExtValue() == 0 && + end.getValue().getSExtValue() == std::numeric_limits::max()) + return getOperand(0); + + auto inType = getOperand(0).getType().dyn_cast(); + auto outType = getResult().getType().dyn_cast(); + if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() || + !inType.hasDtype() || !outType.hasDtype() || + inType.getDtype() != outType.getDtype()) return nullptr; + if (inType.getSizes().size() != outType.getSizes().size() || !inType.areAllSizesKnown() || !outType.areAllSizesKnown()) return nullptr; + + if (input && input.isSplat()) + return DenseElementsAttr::get( + outType.toBuiltinTensor().clone(inType.getDtype()), + input.getSplatValue()); + + int count = 1; + for (auto dim : outType.getSizes()) + count = count * dim; + + if (count == 0) + return {}; + + if (!dim) + return nullptr; + int64_t dimInt = dim.getValue().getSExtValue(); + if (dimInt < 0) + dimInt += inType.getSizes().size(); + + bool unaryNonDim = true; + for (int i = 0, s = outType.getSizes().size(); i < s; ++i) + unaryNonDim &= outType.getSizes()[i] == 1 || i == dimInt; + + // Fold the slice if the output tensor is relatively small, currently + // coded to 16: + if (input && start && step && dim && count < 16 && unaryNonDim && + count < 16) { + int64_t inCount = input.getNumElements(); + int64_t begin = start.getValue().getSExtValue(); + int64_t stride = step.getValue().getSExtValue(); + if (stride < 1) + return {}; + int64_t limit = end.getValue().getSExtValue(); + begin = begin < 0 ? begin + inCount : begin; + limit = limit < 0 ? limit + inCount : limit; + limit = limit < 0 ? inType.getSizes()[dimInt] : limit; + limit = std::min(limit, inType.getSizes()[dimInt]); + + llvm::SmallVector values; + for (int i = begin; i < limit; i += stride) + values.push_back(input.getValues()[i]); + + return DenseElementsAttr::get( + outType.toBuiltinTensor().clone(inType.getDtype()), values); + } + + // If the input and output shapes are the same we can just fold: for (size_t i = 0; i < inType.getSizes().size(); ++i) { if (inType.getSizes()[i] != outType.getSizes()[i]) return nullptr; @@ -2559,6 +3370,25 @@ OpFoldResult AtenAddOp::fold(FoldAdaptor adaptor) { [](double a, double b) -> double { return a + b; }); } +//===----------------------------------------------------------------------===// +// AtenMulOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenMulOp::fold(FoldAdaptor adaptor) { + if (!adaptor.getA() || !adaptor.getB()) { + return nullptr; + } + + if (adaptor.getA().isa() && adaptor.getB().isa()) { + return atenBinaryIntOperatorFoldHelper( + adaptor.getOperands(), + [](int64_t a, int64_t b) -> int64_t { return a * b; }); + } + return atenBinaryFloatOperatorFoldHelper( + adaptor.getOperands(), + [](double a, double b) -> double { return a * b; }); +} + //===----------------------------------------------------------------------===// // AtenSubOp //===----------------------------------------------------------------------===// @@ -2804,6 +3634,55 @@ void AtenDeviceWithIndexOp::getCanonicalizationPatterns( }); } +//===----------------------------------------------------------------------===// +// AtenTensorOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) { + // If a torch.aten.tensor op is initialized by a list with a constant, single + // element, fold it into a torch.vtensor.literal + auto resultTy = dyn_cast(getType()); + Type eTy = resultTy.getDtype(); + ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); + + SmallVector data; + if (matchPattern(getData(), m_TorchListOfConstantInts(data)) && + data.size() == 1) { + Attribute attribute = IntegerAttr::get(eTy, data[0]); + return DenseElementsAttr::get(shapedTy, attribute); + } + + return nullptr; +} + +//===----------------------------------------------------------------------===// +// AtenTensorOp +//===----------------------------------------------------------------------===// + +OpFoldResult Aten_ShapeAsTensorOp::fold(FoldAdaptor adaptor) { + auto selfTy = dyn_cast(getSelf().getType()); + auto resultTy = dyn_cast(getType()); + if (!selfTy || !resultTy || !selfTy.hasSizes() || !resultTy.hasDtype() || + !resultTy.hasSizes()) + return {}; + + llvm::SmallVector values(selfTy.getSizes()); + if (llvm::any_of(values, [](int64_t d) { return d == Torch::kUnknownSize; })) + return {}; + + auto dty = dyn_cast(resultTy.getDtype()); + if (!dty) + return {}; + + llvm::SmallVector attrs; + for (auto val : values) { + attrs.push_back(IntegerAttr::get(dty, val)); + } + + auto attrty = RankedTensorType::get(resultTy.getSizes(), dty); + return DenseElementsAttr::get(attrty, attrs); +} + //===----------------------------------------------------------------------===// // AtenIntTensorOp //===----------------------------------------------------------------------===// @@ -2813,7 +3692,7 @@ OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) { // aten.Int.Tensor, fold to the scalar number. if (auto numToTensorScalar = getA().getDefiningOp()) return numToTensorScalar.getA(); - if (auto tensorIntOp = getA().getDefiningOp()) + if (auto tensorIntOp = getA().getDefiningOp()) return tensorIntOp.getT(); return nullptr; } @@ -2860,6 +3739,240 @@ OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// AtenIndexSelectOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) { + auto self = getSelf(); + auto index = getIndex(); + auto selfTy = dyn_cast(self.getType()); + auto indexTy = dyn_cast(index.getType()); + auto resultTy = dyn_cast(getType()); + if (!selfTy || !indexTy || !resultTy || !selfTy.hasSizes() || + !indexTy.hasSizes() || !resultTy.hasSizes() || !selfTy.hasDtype() || + !indexTy.hasDtype() || !resultTy.hasDtype()) + return nullptr; + + auto selfSizes = selfTy.getSizes(); + auto indexSizes = indexTy.getSizes(); + auto resultSizes = resultTy.getSizes(); + + if (selfTy.getDtype() != resultTy.getDtype() || + selfSizes.size() != resultSizes.size() || indexSizes.size() != 1) + return nullptr; + + // If the selection results in a tensor of the same dimensions as the + // input, the selection must have specified every index of the input, + // so the result is exactly the same as the input. + + bool fullTensor = true; + for (int i = 0, s = selfSizes.size(); i < s; ++i) { + fullTensor &= selfSizes[i] == resultSizes[i]; + fullTensor &= selfSizes[i] != Torch::kUnknownSize; + fullTensor &= resultSizes[i] != Torch::kUnknownSize; + } + + if (fullTensor && indexSizes[0] == 1) + return self; + + // If the input tensor, index dimension, or indexes are non-constant, + // can't fold. + + auto selfAttr = dyn_cast_or_null(adaptor.getSelf()); + auto dimAttr = dyn_cast_or_null(adaptor.getDim()); + auto indexAttr = dyn_cast_or_null(adaptor.getIndex()); + + if (!selfAttr || !dimAttr || !indexAttr) + return {}; + + // If the input's dimensions are all 1 except for one dimension, and if + // there is a single index in the index list (as detected by the result + // dimension being 1), then fold to a <1x1x...x1> tensor literal containing + // a single element. Handles float and int types. + + int64_t dimInt = dimAttr.getInt(); + // If the selected dim is negative, count backwards from the last dim + if (dimInt < 0) + dimInt = selfSizes.size() + dimInt; + assert(uint64_t(dimInt) < selfSizes.size() && + "Selected dim > number of dims"); + + for (int i = 0, s = selfSizes.size(); i < s; ++i) { + if ((selfSizes[i] != 1 && i != dimInt) || resultSizes[i] != 1) + return nullptr; + } + + // Get the single index value for the selected dimension + auto splatValue = indexAttr.getSplatValue(); + int64_t indexInt = getIntAttrAsSigned(splatValue); + indexInt = indexInt < 0 && selfSizes[dimInt] ? indexInt + selfSizes[dimInt] + : indexInt; + + // Extract the single constant value from the input tensor and turn the + // extracted value into a single-element tensor of the output shape and dtype + Attribute splattr = selfAttr.isSplat() + ? selfAttr.getSplatValue() + : selfAttr.getValues()[indexInt]; + + auto dty = resultTy.getDtype(); + auto attrTy = resultTy.toBuiltinTensor().clone(dty); + if (auto floatAttr = dyn_cast(splattr)) + return DenseElementsAttr::get( + attrTy, FloatAttr::get(dty, floatAttr.getValueAsDouble())); + + if (auto intAttr = dyn_cast(splattr)) { + return DenseElementsAttr::get(attrTy, + IntegerAttr::get(dty, intAttr.getValue())); + } + return nullptr; +} + +//===----------------------------------------------------------------------===// +// AtenItemOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenItemOp::fold(FoldAdaptor adaptor) { + // see if we have a constant tensor + DenseElementsAttr attr; + if (matchPattern(getOperand(), m_Constant(&attr))) { + auto splat = attr.getSplatValue(); + if (auto intAttr = dyn_cast(splat)) { + return getI64IntegerAttr(getContext(), intAttr.getSInt()); + } + if (auto floatAttr = dyn_cast(splat)) { + return getF64FloatAttr(getContext(), floatAttr.getValueAsDouble()); + } + return nullptr; + } + + return nullptr; +} + +//===----------------------------------------------------------------------===// +// AtenOnesOp, AtenZerosOp, AtenFullOp +//===----------------------------------------------------------------------===// +OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) { + SmallVector sizes; + if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) { + return nullptr; + } + + Type resultType = getResult().getType(); + BaseTensorType resultTensorType = resultType.dyn_cast(); + if (!resultTensorType || !resultTensorType.hasDtype() || + !resultTensorType.hasSizes()) { + return nullptr; + } + + for (auto sz : sizes) + if (sz == Torch::kUnknownSize || sz < 0) + return nullptr; + + for (auto sz : resultTensorType.getSizes()) + if (sz == Torch::kUnknownSize || sz < 0) + return nullptr; + + ShapedType shapedty = + mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType + sizes, resultTensorType.getDtype()); + if (!shapedty) { + return nullptr; + } + auto elementType = shapedty.getElementType(); + if (elementType.isa()) { + Attribute attribute = IntegerAttr::get(elementType, 1); + return DenseElementsAttr::get(shapedty, attribute); + } + if (elementType.isa()) { + Attribute attribute = FloatAttr::get(elementType, 1.0); + return DenseElementsAttr::get(shapedty, attribute); + } + return nullptr; +} + +OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) { + SmallVector sizes; + if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) { + return nullptr; + } + + Type resultType = getResult().getType(); + BaseTensorType resultTensorType = resultType.dyn_cast(); + if (!resultTensorType || !resultTensorType.hasDtype() || + !resultTensorType.hasSizes()) { + return nullptr; + } + + for (auto sz : sizes) + if (sz == Torch::kUnknownSize || sz < 0) + return nullptr; + + for (auto sz : resultTensorType.getSizes()) + if (sz == Torch::kUnknownSize || sz < 0) + return nullptr; + + ShapedType shapedty = + mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType + sizes, resultTensorType.getDtype()); + if (!shapedty) { + return nullptr; + } + + auto elementType = shapedty.getElementType(); + if (elementType.isa()) { + Attribute attribute = IntegerAttr::get(elementType, 0); + return DenseElementsAttr::get(shapedty, attribute); + } + if (elementType.isa()) { + Attribute attribute = FloatAttr::get(elementType, 0.0); + return DenseElementsAttr::get(shapedty, attribute); + } + + return nullptr; +} + +OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) { + SmallVector sizes; + if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) { + return nullptr; + } + + Type resultType = getResult().getType(); + BaseTensorType resultTensorType = resultType.dyn_cast(); + if (!resultTensorType || !resultTensorType.hasDtype() || + !resultTensorType.hasSizes()) { + return nullptr; + } + + for (auto sz : sizes) + if (sz == Torch::kUnknownSize || sz < 0) + return nullptr; + + for (auto sz : resultTensorType.getSizes()) + if (sz == Torch::kUnknownSize || sz < 0) + return nullptr; + + ShapedType shapedty = + mlir::RankedTensorType::get(sizes, resultTensorType.getDtype()); + + auto elementType = shapedty.getElementType(); + if (elementType.isa()) { + int64_t value = 0; + if (matchPattern(getFillValue(), m_TorchConstantInt(&value))) { + Attribute attribute = IntegerAttr::get(elementType, value); + return DenseElementsAttr::get(shapedty, attribute); + } + } + if (elementType.isa()) { + double value = 0.0; + if (matchPattern(getFillValue(), m_TorchConstantFloat(&value))) { + Attribute attribute = FloatAttr::get(elementType, value); + return DenseElementsAttr::get(shapedty, attribute); + } + } + return nullptr; +} //===----------------------------------------------------------------------===// // AtenCeilFloatOp //===----------------------------------------------------------------------===// @@ -2871,6 +3984,126 @@ OpFoldResult AtenCeilFloatOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// AtenWhereSelfOp +//===----------------------------------------------------------------------===// + +static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) { + if (!attr || !ty.hasDtype() || !ty.hasSizes()) + return nullptr; + + auto dty = ty.getDtype(); + + if (auto valueDense = dyn_cast(attr)) { + if (!valueDense.isSplat()) + return nullptr; + auto splattr = valueDense.getSplatValue(); + auto attrty = ty.toBuiltinTensor().clone(dty); + return DenseElementsAttr::get(attrty, splattr); + } + + if (auto intAttr = dyn_cast_or_null(attr)) { + if (!isa(dty)) + return nullptr; + int64_t intval = intAttr.getInt(); + auto attrty = ty.toBuiltinTensor().clone(dty); + return DenseElementsAttr::get(attrty, IntegerAttr::get(dty, intval)); + } + + if (auto fpAttr = dyn_cast_or_null(attr)) { + if (!isa(dty)) + return nullptr; + double dblval = fpAttr.getValueAsDouble(); + auto attrty = ty.toBuiltinTensor().clone(dty); + return DenseElementsAttr::get(attrty, FloatAttr::get(dty, dblval)); + } + + return nullptr; +} + +OpFoldResult AtenWhereSelfOp::fold(FoldAdaptor adaptor) { + auto dense = dyn_cast_or_null(adaptor.getCondition()); + auto resultTy = dyn_cast(getType()); + if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense || + !dense.isSplat()) + return nullptr; + + auto condattr = dense.getSplatValue(); + auto value = getSelf(); + auto valueAttr = adaptor.getSelf(); + if (condattr.isZero()) { + value = getOther(); + valueAttr = adaptor.getOther(); + } + + auto valueTy = dyn_cast(value.getType()); + if (valueTy && valueTy.hasSizes() && valueTy.hasDtype() && + valueTy == resultTy) + return value; + + return getBroadcastedAttr(valueAttr, resultTy); +} + +//===----------------------------------------------------------------------===// +// AtenWhereScalarOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenWhereScalarOp::fold(FoldAdaptor adaptor) { + auto dense = dyn_cast_or_null(adaptor.getCondition()); + auto resultTy = dyn_cast(getType()); + if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense || + !dense.isSplat()) + return nullptr; + + auto condattr = dense.getSplatValue(); + auto valueAttr = adaptor.getSelf(); + if (condattr.isZero()) { + valueAttr = adaptor.getOther(); + } + + return getBroadcastedAttr(valueAttr, resultTy); +} + +//===----------------------------------------------------------------------===// +// AtenWhereScalarOtherOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenWhereScalarOtherOp::fold(FoldAdaptor adaptor) { + auto dense = dyn_cast_or_null(adaptor.getCondition()); + auto resultTy = dyn_cast(getType()); + if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense || + !dense.isSplat()) + return nullptr; + + auto condattr = dense.getSplatValue(); + auto valueAttr = adaptor.getSelf(); + if (condattr.isZero()) { + valueAttr = adaptor.getOther(); + } + + return getBroadcastedAttr(valueAttr, resultTy); +} + +//===----------------------------------------------------------------------===// +// AtenWhereScalarSelfOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenWhereScalarSelfOp::fold(FoldAdaptor adaptor) { + auto dense = dyn_cast_or_null(adaptor.getCondition()); + auto resultTy = dyn_cast(getType()); + if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense || + !dense.isSplat()) + return nullptr; + + auto condattr = dense.getSplatValue(); + auto valueAttr = adaptor.getSelf(); + if (condattr.isZero()) { + valueAttr = adaptor.getOther(); + } + + return getBroadcastedAttr(valueAttr, resultTy); +} + //===----------------------------------------------------------------------===// // PrimMaxIntOp //===----------------------------------------------------------------------===// @@ -2890,6 +4123,30 @@ OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) { std::max(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue())); } +//===----------------------------------------------------------------------===// +// PrimNumToTensorScalarOp +//===----------------------------------------------------------------------===// + +OpFoldResult PrimNumToTensorScalarOp::fold(FoldAdaptor adaptor) { + Attribute a = adaptor.getA(); + auto resultTy = cast(getType()); + if (!a) + return {}; + if (!resultTy.hasDtype() || !resultTy.hasSizes()) + return {}; + + auto dty = resultTy.getDtype(); + if (auto iattr = dyn_cast(a)) { + a = IntegerAttr::get(dty, iattr.getInt()); + } else if (auto fattr = dyn_cast(a)) { + a = FloatAttr::get(dty, fattr.getValueAsDouble()); + } + + auto mlirTensorType = + RankedTensorType::get(resultTy.getSizes(), resultTy.getDtype()); + return SplatElementsAttr::get(mlirTensorType, a); +} + //===----------------------------------------------------------------------===// // PrimMinSelfIntOp //===----------------------------------------------------------------------===// @@ -2988,6 +4245,42 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// AtenNormScalarOp +//===----------------------------------------------------------------------===// + +LogicalResult AtenNormScalarOp::verify() { + + // Verificaion of input type for torch.aten.norm.Scalar. + // Per PyTorch docs, only float and complex types are valid for norm + // operation. + + auto inTensor = getSelf().getType().cast(); + + // If no dtype is specified, it will default to a float one. + if (!inTensor.hasDtype()) { + return success(); + } + + auto inTensorDtype = inTensor.getDtype(); + + // Check if dtype is one of those supported by norm operation. + // ComplexType will match any torch complex types, but each float must be + // checked individually. + if (!inTensorDtype.isa()) { + return emitOpError( + "expected a float or complex type for input tensor, but got ") + << inTensorDtype; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// AtenPermuteOp +//===----------------------------------------------------------------------===// + LogicalResult AtenPermuteOp::verify() { // Verification of the permute op for input & output dimensions with @@ -3024,7 +4317,6 @@ LogicalResult AtenPermuteOp::verify() { << " elements, the output has rank " << outRank << '.'; } - // Initialization of the reverse permutation. -1 denotes an unknown // permutation index. SmallVector reversePermutation(outRank, -1); @@ -3078,6 +4370,96 @@ LogicalResult AtenPermuteOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// AtenLinalgCrossOp +//===----------------------------------------------------------------------===// + +LogicalResult AtenLinalgCrossOp::verify() { + + auto selfType = getSelf().getType().cast(); + auto otherType = getOther().getType().cast(); + + if (!selfType.hasDtype() || !otherType.hasDtype() || !selfType.hasSizes() || + !otherType.hasSizes()) { + return success(); + } + + Type selfDtype = selfType.getDtype(); + Type otherDtype = otherType.getDtype(); + + // the operation succeeds only if both inputs have the same dtype + if (selfDtype != otherDtype) { + return emitOpError("input tensors must have the same dtype, but got ") + << selfDtype << " and " << otherDtype; + } + + // Check if any of the input tensors has torch.bool dtype. + // The operation does not support this type. + // The docs state that only float, double, cfloat and cdouble dtypes are + // supported, but, when testing, it fails only for boolean dtype. Update to + // fit the docs if necessary. + // https://pytorch.org/docs/stable/generated/torch.linalg.cross.html + if (selfDtype.isSignlessInteger(1) || otherDtype.isSignlessInteger(1)) { + return emitOpError("input tensors must not have bool dtype"); + } + + ArrayRef selfShape = selfType.getSizes(); + ArrayRef otherShape = otherType.getSizes(); + + int64_t selfRank = selfShape.size(); + int64_t otherRank = otherShape.size(); + + // check if both input tensors have the same number of dims + if (selfRank != otherRank) { + return emitOpError("input tensors must have the same number of dimensions, " + "but got ") + << selfRank << " and " << otherRank; + } + + // convert dim to an integer type + int64_t dim; + if (!matchPattern(getDim(), m_TorchConstantInt(&dim))) { + return success(); + } + + // check if dim is in the correct range + if (dim >= selfRank || dim < -selfRank) { + return emitOpError("dim expected to be in rank of [") + << -selfRank << ", " << selfRank - 1 << "], but got " << dim; + } + + // compensate for possible negative dim value + if (dim < 0) { + dim += selfRank; + } + + // check if the size of the dimensions specified by 'dim' is equal to 3 + // (required by the operation) + if ((selfShape[dim] != 3 && selfShape[dim] != kUnknownSize) || + (otherShape[dim] != 3 && otherShape[dim] != kUnknownSize)) { + return emitOpError("inputs dimension ") + << dim << " must have length 3, but got " << selfShape[dim] + << " and " << otherShape[dim]; + } + + // Check if there is a disparity between dimension sizes. + // Dimensions at the same index must either have the same size, + // or one of them must be equal to 1. + int32_t i = 0; + for (auto [selfCurrent, otherCurrent] : + llvm::zip_equal(selfShape, otherShape)) { + if (selfCurrent != otherCurrent && selfCurrent != 1 && otherCurrent != 1) { + return emitOpError("the size of first tensor (") + << selfCurrent << ") must match the size of second tensor (" + << otherCurrent << ") at dimension " << i + << " or one of them must be 1"; + } + ++i; + } + + return success(); +} + //===----------------------------------------------------------------------===// // DtypeCalculateYieldDtypesOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index cf832b1b755e..b22c82b8a28f 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -8,10 +8,11 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/IR/DialectImplementation.h" -#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/ADT/STLExtras.h" using namespace mlir; @@ -184,11 +185,17 @@ static bool isValidTorchDtype(Type dtype) { dtype = dtype.cast().getElementType(); } // Torch quantized types. - if (dtype.isa()) + if (dtype.isa()) return true; // Builtin floating point types. if (dtype.isa()) return true; + if (dtype.isa()) + return true; + + if (dtype.isa()) + return true; // Builtin integer types. if (IntegerType type = dtype.dyn_cast()) { if (type.isSignless() && type.getWidth() == 1) @@ -239,7 +246,7 @@ ValueTensorType BaseTensorType::getWithValueSemantics() const { static LogicalResult verifyTensorType(function_ref emitError, std::optional> optionalSizes, - Type optionalDtype) { + Type optionalDtype, Attribute optionalSparsity) { if (optionalDtype && !isValidTorchDtype(optionalDtype)) { emitError() << "invalid dtype " << optionalDtype << " for !torch.tensor type"; @@ -253,6 +260,24 @@ verifyTensorType(function_ref emitError, } } } + // Verify sparsity encoding against a known type and shape using the encoding + // verification interface. Any implementation emits a diagnostic on failure. + // Also verify sparsity encoding is truly a sparse encoding attrbute. + if (optionalSparsity) { + if (optionalDtype && optionalSizes.has_value()) { + if (auto venc = llvm::dyn_cast_or_null( + optionalSparsity)) { + if (failed(venc.verifyEncoding(optionalSizes.value(), optionalDtype, + emitError))) { + return failure(); + } + } + } + if (!optionalSparsity.isa()) { + emitError() << "invalid sparsity encoding attribute"; + return failure(); + } + } return success(); } @@ -262,7 +287,8 @@ Type parseTensorType(MLIRContext *context, AsmParser &parser, if (parser.parseOptionalLess()) return getTensorType(context, /*optionalSizes=*/std::nullopt, - /*optionalDtype=*/Type()); + /*optionalDtype=*/Type(), + /*optionalSparsity=*/Attribute()); bool hasSizes; SmallVector sizes; if (succeeded(parser.parseOptionalStar())) { @@ -307,6 +333,12 @@ Type parseTensorType(MLIRContext *context, AsmParser &parser, if (parser.parseType(optionalDtype)) return Type(); } + Attribute optionalSparsity; + if (succeeded(parser.parseOptionalComma())) { + // Explicit encoding. + if (parser.parseAttribute(optionalSparsity)) + return Type(); + } if (parser.parseGreater()) return Type(); std::optional> optionalSizes; @@ -314,15 +346,15 @@ Type parseTensorType(MLIRContext *context, AsmParser &parser, optionalSizes.emplace(sizes); if (failed(verifyTensorType([&]() { return parser.emitError(startLoc); }, - optionalSizes, optionalDtype))) + optionalSizes, optionalDtype, optionalSparsity))) return Type(); - return getTensorType(context, optionalSizes, optionalDtype); + return getTensorType(context, optionalSizes, optionalDtype, optionalSparsity); } static void printTensorType(AsmPrinter &printer, std::optional> optionalSizes, - Type optionalDtype) { + Type optionalDtype, Attribute optionalSparsity) { if (!optionalSizes && !optionalDtype) return; printer << "<"; @@ -345,6 +377,10 @@ static void printTensorType(AsmPrinter &printer, printer.printType(optionalDtype); else printer << "unk"; + if (optionalSparsity) { + printer << ","; + printer.printAttribute(optionalSparsity); + } printer << ">"; } @@ -367,8 +403,9 @@ NonValueTensorType::getWithLeastStaticInformation(MLIRContext *context) { LogicalResult NonValueTensorType::verify(function_ref emitError, std::optional> optionalSizes, - Type optionalDtype) { - return verifyTensorType(emitError, optionalSizes, optionalDtype); + Type optionalDtype, Attribute optionalSparsity) { + return verifyTensorType(emitError, optionalSizes, optionalDtype, + optionalSparsity); } Type NonValueTensorType::parse(AsmParser &parser) { @@ -376,13 +413,15 @@ Type NonValueTensorType::parse(AsmParser &parser) { return parseTensorType( context, parser, [](MLIRContext *context, std::optional> optionalSizes, - Type optionalType) { - return NonValueTensorType::get(context, optionalSizes, optionalType); + Type optionalType, Attribute optionalSparsity) { + return NonValueTensorType::get(context, optionalSizes, optionalType, + optionalSparsity); }); } void NonValueTensorType::print(AsmPrinter &printer) const { - printTensorType(printer, getOptionalSizes(), getOptionalDtype()); + printTensorType(printer, getOptionalSizes(), getOptionalDtype(), + getOptionalSparsity()); } //===----------------------------------------------------------------------===// @@ -407,9 +446,19 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) { } else if (auto integerType = dtype.dyn_cast()) { return IntegerType::get(context, integerType.getWidth(), IntegerType::Signless); - } else if (dtype.isa()){ + } else if (dtype.isa()) { return dtype; } + + if (isa(dtype)) + return IntegerType::get(context, 8, IntegerType::Signless); + + if (isa(dtype)) + return IntegerType::get(context, 8, IntegerType::Signless); + + if (isa(dtype)) + return IntegerType::get(context, 32, IntegerType::Signless); + emitError(UnknownLoc::get(context)) << "unimplemented: conversion of dtype " << dtype << " to builtin tensor element type"; @@ -424,14 +473,16 @@ TensorType ValueTensorType::toBuiltinTensor() const { Type elementType = convertDtypeToBuiltinElementType(getContext(), getDtype()); if (!elementType) return nullptr; - return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType); + return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType, + getOptionalSparsity()); } LogicalResult ValueTensorType::verify(function_ref emitError, std::optional> optionalSizes, - Type optionalDtype) { - return verifyTensorType(emitError, optionalSizes, optionalDtype); + Type optionalDtype, Attribute optionalSparsity) { + return verifyTensorType(emitError, optionalSizes, optionalDtype, + optionalSparsity); } Type ValueTensorType::parse(AsmParser &parser) { @@ -439,13 +490,15 @@ Type ValueTensorType::parse(AsmParser &parser) { return parseTensorType( context, parser, [](MLIRContext *context, std::optional> optionalSizes, - Type optionalType) { - return ValueTensorType::get(context, optionalSizes, optionalType); + Type optionalType, Attribute optionalSparsity) { + return ValueTensorType::get(context, optionalSizes, optionalType, + optionalSparsity); }); } void ValueTensorType::print(AsmPrinter &printer) const { - printTensorType(printer, getOptionalSizes(), getOptionalDtype()); + printTensorType(printer, getOptionalSizes(), getOptionalDtype(), + getOptionalSparsity()); } Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) { @@ -509,9 +562,9 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) { // TODO: These are not DRY in that the two type predicates AnyTorchDictKeyType // and AnyTorchType generate the exact same code (in TorchOps.cpp.inc). -// Unfortunately the generated implementations aren't visible/exposed ("static" linkage) -// and the predicates themselves can't be added/used in the specification of the parameters -// of the Torch_DictType. +// Unfortunately the generated implementations aren't visible/exposed ("static" +// linkage) and the predicates themselves can't be added/used in the +// specification of the parameters of the Torch_DictType. static bool isAnyTorchDictKeyType(Type type) { return type.isa() || type.isa() || type.isa() || type.isa() || diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index a02465399a9c..06d36f58d1c8 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -122,7 +122,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int0 = torch.constant.int 0\n" -" %0 = torch.operator \"aten.ge\"(%arg0, %int0) : (!torch.union, !torch.int) -> !torch.bool\n" +" %0 = torch.operator \"aten.ge\"(%arg0, %int0) : (!torch.union, !torch.int) -> !torch.bool \n" " torch.prim.If %0 -> () {\n" " torch.prim.If.yield\n" " } else {\n" @@ -138,14 +138,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int0 = torch.constant.int 0\n" -" %0 = torch.operator \"aten.ge\"(%arg1, %int0) : (!torch.union, !torch.int) -> !torch.bool\n" +" %0 = torch.operator \"aten.ge\"(%arg1, %int0) : (!torch.union, !torch.int) -> !torch.bool \n" " torch.prim.If %0 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = torch.operator \"aten.ge\"(%arg1, %arg0) : (!torch.union, !torch.union) -> !torch.bool\n" +" %1 = torch.operator \"aten.ge\"(%arg1, %arg0) : (!torch.union, !torch.union) -> !torch.bool \n" " torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" @@ -162,16 +162,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int0 = torch.constant.int 0\n" -" %0 = torch.operator \"aten.ne\"(%arg2, %int0) : (!torch.union, !torch.int) -> !torch.bool\n" +" %0 = torch.operator \"aten.ne\"(%arg2, %int0) : (!torch.union, !torch.int) -> !torch.bool \n" " torch.prim.If %0 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = torch.operator \"aten.lt\"(%arg2, %int0) : (!torch.union, !torch.int) -> !torch.bool\n" +" %1 = torch.operator \"aten.lt\"(%arg2, %int0) : (!torch.union, !torch.int) -> !torch.bool \n" " torch.prim.If %1 -> () {\n" -" %6 = torch.operator \"aten.ge\"(%arg0, %arg1) : (!torch.union, !torch.union) -> !torch.bool\n" +" %6 = torch.operator \"aten.ge\"(%arg0, %arg1) : (!torch.union, !torch.union) -> !torch.bool \n" " torch.prim.If %6 -> () {\n" " torch.prim.If.yield\n" " } else {\n" @@ -180,7 +180,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " torch.prim.If.yield\n" " } else {\n" -" %6 = torch.operator \"aten.ge\"(%arg1, %arg0) : (!torch.union, !torch.union) -> !torch.bool\n" +" %6 = torch.operator \"aten.ge\"(%arg1, %arg0) : (!torch.union, !torch.union) -> !torch.bool \n" " torch.prim.If %6 -> () {\n" " torch.prim.If.yield\n" " } else {\n" @@ -5507,12 +5507,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" " %16 = torch.aten.__getitem__.t %11, %int0 : !torch.list, !torch.int -> !torch.float\n" -" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float\n" +" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n" " %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n" " %19 = torch.aten.append.t %1, %18 : !torch.list, !torch.int -> !torch.list\n" " %20 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" " %21 = torch.aten.__getitem__.t %11, %int1 : !torch.list, !torch.int -> !torch.float\n" -" %22 = torch.operator \"aten.mul.int_float\"(%20, %21) : (!torch.int, !torch.float) -> !torch.float\n" +" %22 = torch.operator \"aten.mul.int_float\"(%20, %21) : (!torch.int, !torch.float) -> !torch.float \n" " %23 = torch.aten.Int.float %22 : !torch.float -> !torch.int\n" " %24 = torch.aten.append.t %1, %23 : !torch.list, !torch.int -> !torch.list\n" " torch.prim.If.yield\n" @@ -6238,63 +6238,151 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.atan\"(%arg0: !torch.list) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten.diagonal\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: diagonal dimensions cannot be identical\"\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: input must have at least two dimensions\"\n" +" %int2 = torch.constant.int 2\n" +" %int9223372036854775807 = torch.constant.int 9223372036854775807\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %3 = call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg2, %2, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int\n" +" %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %5 = call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg3, %4, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int\n" +" %6 = torch.aten.ne.int %3, %5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = torch.prim.ListConstruct : () -> !torch.list\n" +" %8 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %9 = torch.prim.ListConstruct %int9223372036854775807, %8 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = torch.prim.min.self_int %9 : !torch.list -> !torch.int\n" +" torch.prim.Loop %10, %true, init() {\n" +" ^bb0(%arg4: !torch.int):\n" +" %19 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list, !torch.int -> !torch.int\n" +" %20 = torch.aten.eq.int %arg4, %3 : !torch.int, !torch.int -> !torch.bool\n" +" %21 = torch.prim.If %20 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %22 = torch.aten.eq.int %arg4, %5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %22 : !torch.bool\n" +" }\n" +" torch.prim.If %21 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %22 = torch.aten.append.t %7, %19 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %11 = torch.aten.__getitem__.t %arg0, %3 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.__getitem__.t %arg0, %5 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.sub.int %12, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.prim.min.int %11, %13 : !torch.int, !torch.int -> !torch.int\n" +" %15 = torch.prim.max.int %14, %int0 : !torch.int, !torch.int -> !torch.int\n" +" %16 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %17 = torch.prim.If %16 -> (!torch.int) {\n" +" %19 = torch.aten.__getitem__.t %arg0, %3 : !torch.list, !torch.int -> !torch.int\n" +" %20 = torch.aten.add.int %19, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.__getitem__.t %arg0, %5 : !torch.list, !torch.int -> !torch.int\n" +" %22 = torch.prim.min.int %20, %21 : !torch.int, !torch.int -> !torch.int\n" +" %23 = torch.prim.max.int %22, %int0 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %23 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %15 : !torch.int\n" +" }\n" +" %18 = torch.aten.append.t %7, %17 : !torch.list, !torch.int -> !torch.list\n" +" return %7 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.tanh\"(%arg0: !torch.list) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten.asin\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.erf\"(%arg0: !torch.list) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten.asinh\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.sigmoid\"(%arg0: !torch.list) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten.cos\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.hardsigmoid\"(%arg0: !torch.list) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten.cosh\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.softplus\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten.acos\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.square\"(%arg0: !torch.list) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten.acosh\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.hardswish\"(%arg0: !torch.list) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten.tan\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.silu\"(%arg0: !torch.list) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten.tanh\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.exp\"(%arg0: !torch.list) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten.atan\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.expm1\"(%arg0: !torch.list) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten.atanh\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten.erf\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.cos\"(%arg0: !torch.list) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten.sigmoid\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.asin\"(%arg0: !torch.list) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten.hardsigmoid\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.acos\"(%arg0: !torch.list) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten.softplus\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.square\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.hardswish\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.silu\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.exp\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.expm1\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" @@ -6348,6 +6436,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.logit\"(%arg0: !torch.list, %arg1: !torch.optional) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.rsqrt\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6473,10 +6565,49 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.quantize_per_channel\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.int, %arg4: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.quantize_per_tensor\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.dequantize.self\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.dequantize.tensor\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.int_repr\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten._make_per_channel_quantized_tensor\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten._make_per_tensor_quantized_tensor\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.prims.convert_element_type\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.grid_sampler\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.bool) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %2 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.prims.collapse\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list {\n" " %true = torch.constant.bool true\n" " %str = torch.constant.str \"AssertionError: start must be less than or equal to end\"\n" @@ -6486,7 +6617,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int0 = torch.constant.int 0\n" " %int1 = torch.constant.int 1\n" " %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %1 = torch.aten.le.int %arg1, %0 : !torch.int, !torch.int -> !torch.bool\n" +" %1 = torch.aten.lt.int %arg1, %0 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" @@ -6494,7 +6625,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If.yield\n" " }\n" " %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %3 = torch.aten.le.int %arg2, %2 : !torch.int, !torch.int -> !torch.bool\n" +" %3 = torch.aten.lt.int %arg2, %2 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" @@ -6662,6 +6793,57 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.linalg_cross\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %str_0 = torch.constant.str \"the size of first tensor ({}) must match the size of second tensor ({}) at dimension {}\"\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: inputs must have the same number of dimensions\"\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %2 = torch.aten.eq.int %0, %1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %3, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %5 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %6 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %7 = torch.aten.eq.int %5, %6 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %10 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %11 = torch.aten.eq.int %10, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" }\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %10 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %11 = torch.aten.eq.int %10, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" }\n" +" torch.prim.If %9 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %10 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %11 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.format(%str_0, %10, %11, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n" +" %13 = torch.aten.add.str %str, %12 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %13, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %4 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten._log_softmax_backward_data\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6674,6 +6856,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.isneginf\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.isposinf\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.ne.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6722,6 +6912,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.remainder.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.fmod.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.floor_divide.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6750,6 +6948,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.gather\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg2) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6858,6 +7060,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %1 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.trace\"(%arg0: !torch.list) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: input must have rank 2\"\n" +" %int2 = torch.constant.int 2\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.prim.ListConstruct : () -> !torch.list\n" +" return %2 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.argmax\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6886,6 +7103,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" " return %1 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.all.dim\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list {\n" +" %0 = torch.derefine %arg1 : !torch.int to !torch.optional\n" +" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.max.dim\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple, list> {\n" " %0 = torch.derefine %arg1 : !torch.int to !torch.optional\n" " %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" @@ -7100,7 +7322,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.prim.If %2 -> (!torch.list) {\n" " %5 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list\n" " %6 = torch.aten.sub.int %1, %0 : !torch.int, !torch.int -> !torch.int\n" -" %7 = torch.operator \"aten.mul.left_t\"(%5, %6) : (!torch.list, !torch.int) -> !torch.list\n" +" %7 = torch.operator \"aten.mul.left_t\"(%5, %6) : (!torch.list, !torch.int) -> !torch.list \n" " %8 = torch.aten.add.t %7, %arg1 : !torch.list, !torch.list -> !torch.list\n" " torch.prim.If.yield %8 : !torch.list\n" " } else {\n" @@ -7151,6 +7373,403 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.max_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.max_pool3d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__._max_pool3d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__._max_pool3d(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.list {\n" +" %int-1 = torch.constant.int -1\n" +" %int-2 = torch.constant.int -2\n" +" %int-3 = torch.constant.int -3\n" +" %int-4 = torch.constant.int -4\n" +" %int-5 = torch.constant.int -5\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %str_0 = torch.constant.str \"AssertionError: max_pool3d: dilation must be either a single int, or a tuple of three ints\"\n" +" %str_1 = torch.constant.str \"AssertionError: max_pool3d: padding must either be a single int, or a tuple of thee ints\"\n" +" %str_2 = torch.constant.str \"AssertionError: max_pool3d: stride must either be omitted, a single int, or a tuple of three ints\"\n" +" %none = torch.constant.none\n" +" %str_3 = torch.constant.str \"AssertionError: max_pool3d: kernel_size must either be a single int, or a tuple of three ints\"\n" +" %true = torch.constant.bool true\n" +" %int1 = torch.constant.int 1\n" +" %int3 = torch.constant.int 3\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int4 = torch.constant.int 4\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %45 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %46 = torch.aten.eq.int %45, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %46 : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.tuple) {\n" +" %45 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %46 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %47 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %48 = torch.prim.TupleConstruct %45, %46, %47 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %48 : !torch.tuple\n" +" } else {\n" +" %45 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %46 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %47 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %48 = torch.prim.TupleConstruct %45, %46, %47 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %48 : !torch.tuple\n" +" }\n" +" %6:3 = torch.prim.TupleUnpack %5 : !torch.tuple -> !torch.int, !torch.int, !torch.int\n" +" %7 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %45 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %46 = torch.aten.eq.int %45, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %46 : !torch.bool\n" +" }\n" +" %10 = torch.prim.If %9 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %45 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %46 = torch.aten.eq.int %45, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %46 : !torch.bool\n" +" }\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %11 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %12 = torch.aten.eq.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %13:3 = torch.prim.If %12 -> (!torch.int, !torch.int, !torch.int) {\n" +" torch.prim.If.yield %6#0, %6#0, %6#0 : !torch.int, !torch.int, !torch.int\n" +" } else {\n" +" %45 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %46 = torch.aten.eq.int %45, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %47:3 = torch.prim.If %46 -> (!torch.int, !torch.int, !torch.int) {\n" +" %48 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %49 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %50 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %48, %49, %50 : !torch.int, !torch.int, !torch.int\n" +" } else {\n" +" %48 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %49 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %50 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %48, %49, %50 : !torch.int, !torch.int, !torch.int\n" +" }\n" +" torch.prim.If.yield %47#0, %47#1, %47#2 : !torch.int, !torch.int, !torch.int\n" +" }\n" +" %14 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %15 = torch.aten.eq.int %14, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %45 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %46 = torch.aten.eq.int %45, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %46 : !torch.bool\n" +" }\n" +" torch.prim.If %16 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %17 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %18 = torch.aten.eq.int %17, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %19 = torch.prim.If %18 -> (!torch.tuple) {\n" +" %45 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %46 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %47 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %48 = torch.prim.TupleConstruct %45, %46, %47 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %48 : !torch.tuple\n" +" } else {\n" +" %45 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %46 = torch.aten.__getitem__.t %arg3, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %47 = torch.aten.__getitem__.t %arg3, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %48 = torch.prim.TupleConstruct %45, %46, %47 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %48 : !torch.tuple\n" +" }\n" +" %20:3 = torch.prim.TupleUnpack %19 : !torch.tuple -> !torch.int, !torch.int, !torch.int\n" +" %21 = torch.aten.len.t %arg4 : !torch.list -> !torch.int\n" +" %22 = torch.aten.eq.int %21, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %23 = torch.prim.If %22 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %45 = torch.aten.len.t %arg4 : !torch.list -> !torch.int\n" +" %46 = torch.aten.eq.int %45, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %46 : !torch.bool\n" +" }\n" +" torch.prim.If %23 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %24 = torch.aten.len.t %arg4 : !torch.list -> !torch.int\n" +" %25 = torch.aten.eq.int %24, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %26 = torch.prim.If %25 -> (!torch.tuple) {\n" +" %45 = torch.aten.__getitem__.t %arg4, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %46 = torch.aten.__getitem__.t %arg4, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %47 = torch.aten.__getitem__.t %arg4, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %48 = torch.prim.TupleConstruct %45, %46, %47 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %48 : !torch.tuple\n" +" } else {\n" +" %45 = torch.aten.__getitem__.t %arg4, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %46 = torch.aten.__getitem__.t %arg4, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %47 = torch.aten.__getitem__.t %arg4, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %48 = torch.prim.TupleConstruct %45, %46, %47 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %48 : !torch.tuple\n" +" }\n" +" %27:3 = torch.prim.TupleUnpack %26 : !torch.tuple -> !torch.int, !torch.int, !torch.int\n" +" %28 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %29 = torch.aten.eq.int %28, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %30 = torch.prim.If %29 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %45 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %46 = torch.aten.eq.int %45, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %46 : !torch.bool\n" +" }\n" +" torch.prim.If %30 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %31 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %32 = torch.aten.eq.int %31, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" %33 = torch.prim.If %32 -> (!torch.int) {\n" +" %45 = torch.aten.__getitem__.t %arg0, %int-5 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %45 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %34 = torch.aten.__getitem__.t %arg0, %int-4 : !torch.list, !torch.int -> !torch.int\n" +" %35 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list, !torch.int -> !torch.int\n" +" %36 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" +" %37 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %38 = call @__torch__.torch.jit._shape_functions.pooling_output_shape(%35, %6#0, %20#0, %13#0, %27#0, %arg5) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int\n" +" %39 = call @__torch__.torch.jit._shape_functions.pooling_output_shape(%36, %6#1, %20#1, %13#1, %27#1, %arg5) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int\n" +" %40 = call @__torch__.torch.jit._shape_functions.pooling_output_shape(%37, %6#2, %20#2, %13#2, %27#2, %arg5) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int\n" +" %41 = call @__torch__._pool3d_shape_check(%arg0, %6#0, %6#1, %6#2, %13#0, %13#1, %13#2, %20#0, %20#1, %20#2, %27#0, %27#1, %27#2, %38, %39, %40) : (!torch.list, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.none\n" +" %42 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %43 = torch.aten.eq.int %42, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %44 = torch.prim.If %43 -> (!torch.list) {\n" +" %45 = torch.prim.ListConstruct %34, %38, %39, %40 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %45 : !torch.list\n" +" } else {\n" +" %45 = torch.prim.ListConstruct %33, %34, %38, %39, %40 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %45 : !torch.list\n" +" }\n" +" return %44 : !torch.list\n" +" }\n" +" func.func @__torch__._pool3d_shape_check(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.int, %arg8: !torch.int, %arg9: !torch.int, %arg10: !torch.int, %arg11: !torch.int, %arg12: !torch.int, %arg13: !torch.int, %arg14: !torch.int, %arg15: !torch.int) -> !torch.none {\n" +" %str = torch.constant.str \"AssertionError: pool3d: input dimensions must be 4 or 5\"\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int0 = torch.constant.int 0\n" +" %int4 = torch.constant.int 4\n" +" %int5 = torch.constant.int 5\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.gt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %20 = torch.aten.gt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" %20 = torch.aten.gt.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.gt.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" %20 = torch.aten.gt.int %arg5, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" +" %20 = torch.aten.gt.int %arg6, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = torch.aten.gt.int %arg10, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %20 = torch.aten.gt.int %arg11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" %20 = torch.aten.gt.int %arg12, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %9 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %10 = torch.aten.eq.int %0, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %20 = torch.aten.eq.int %0, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" }\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %12 = torch.aten.eq.int %0, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %12 -> () {\n" +" %20 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.ne.int %20, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %22 = torch.prim.If %21 -> (!torch.bool) {\n" +" %25 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %26 = torch.aten.ne.int %25, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %26 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %23 = torch.prim.If %22 -> (!torch.bool) {\n" +" %25 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %26 = torch.aten.ne.int %25, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %26 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %24 = torch.prim.If %23 -> (!torch.bool) {\n" +" %25 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" +" %26 = torch.aten.ne.int %25, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %26 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %24 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" %20 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.ne.int %20, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %22 = torch.prim.If %21 -> (!torch.bool) {\n" +" %26 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %27 = torch.aten.ne.int %26, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %27 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %23 = torch.prim.If %22 -> (!torch.bool) {\n" +" %26 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %27 = torch.aten.ne.int %26, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %27 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %24 = torch.prim.If %23 -> (!torch.bool) {\n" +" %26 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" +" %27 = torch.aten.ne.int %26, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %27 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %25 = torch.prim.If %24 -> (!torch.bool) {\n" +" %26 = torch.aten.__getitem__.t %arg0, %int4 : !torch.list, !torch.int -> !torch.int\n" +" %27 = torch.aten.ne.int %26, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %27 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %25 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" }\n" +" %13 = torch.aten.floordiv.int %arg1, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.aten.ge.int %13, %arg7 : !torch.int, !torch.int -> !torch.bool\n" +" %15 = torch.prim.If %14 -> (!torch.bool) {\n" +" %20 = torch.aten.floordiv.int %arg3, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.ge.int %20, %arg9 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %21 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %16 = torch.prim.If %15 -> (!torch.bool) {\n" +" %20 = torch.aten.floordiv.int %arg2, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.ge.int %20, %arg8 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %21 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %16 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %17 = torch.aten.ge.int %arg13, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %18 = torch.prim.If %17 -> (!torch.bool) {\n" +" %20 = torch.aten.ge.int %arg15, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %19 = torch.prim.If %18 -> (!torch.bool) {\n" +" %20 = torch.aten.ge.int %arg14, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %19 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %none : !torch.none\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.max_pool2d_with_indices\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.tuple, list> {\n" " %0 = call @__torch__.torch.jit._shape_functions.max_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool) -> !torch.list\n" " %1 = torch.prim.TupleConstruct %0, %0 : !torch.list, !torch.list -> !torch.tuple, list>\n" @@ -7476,6 +8095,73 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.adaptive_avg_pool2d(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.adaptive_max_pool2d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.tuple, list> {\n" +" %0 = call @__torch__.adaptive_max_pool2d(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.tuple, list>\n" +" return %0 : !torch.tuple, list>\n" +" }\n" +" func.func @__torch__.adaptive_max_pool2d(%arg0: !torch.list, %arg1: !torch.list) -> !torch.tuple, list> {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %int4 = torch.constant.int 4\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %11 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %12 = torch.aten.eq.int %11, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %5, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %11 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.ne.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %12 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %6 = torch.prim.ListConstruct : () -> !torch.list\n" +" %7 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %8 = torch.aten.sub.int %7, %int2 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %8, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %11 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.append.t %6, %11 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %9 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" torch.prim.Loop %9, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %11 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.append.t %6, %11 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %10 = torch.prim.TupleConstruct %6, %6 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %10 : !torch.tuple, list>\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.flatten.using_ints\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.flatten(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7593,6 +8279,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.uniform\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.exponential\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.any) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.rand\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" @@ -7633,6 +8322,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.randn.generator\"(%arg0: !torch.list, %arg1: !torch.any, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.normal_functional\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.arange.start_step\"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.list {\n" " %0 = torch.derefine %arg0 : !torch.float to !torch.union\n" " %1 = torch.derefine %arg1 : !torch.float to !torch.union\n" @@ -7685,6 +8377,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple, !torch.int -> !torch.int\n" " return %0 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.linspace\"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.list {\n" +" %0 = torch.prim.ListConstruct %arg2 : (!torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.add.Tensor\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7882,11 +8578,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg2) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.nan_to_num\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.lerp.Tensor\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list, !torch.list) -> !torch.list\n" " %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list, !torch.list) -> !torch.list\n" " return %1 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.lerp.Scalar\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.addcmul\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list, !torch.list) -> !torch.list\n" " %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list, !torch.list) -> !torch.list\n" @@ -8090,6 +8794,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.conv3d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.conv3d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.conv_transpose2d.input\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.list {\n" " %0 = torch.derefine %arg3 : !torch.list to !torch.optional>\n" " %1 = torch.derefine %arg4 : !torch.list to !torch.optional>\n" @@ -8098,10 +8806,67 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch.jit._shape_functions.conv_transpose2d_input(%arg0, %arg1, %arg2, %0, %1, %2, %arg6, %3) : (!torch.list, !torch.list, !torch.optional>, !torch.optional>, !torch.optional>, !torch.optional>, !torch.int, !torch.optional>) -> !torch.list\n" " return %4 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.conv_tbc\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.int) -> !torch.list {\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int3 = torch.constant.int 3\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %5 = torch.aten.eq.int %4, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %7 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.eq.int %6, %7 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list\n" +" %10 = torch.prim.ListConstruct %arg3 : (!torch.int) -> !torch.list\n" +" %11 = torch.prim.ListConstruct : () -> !torch.list\n" +" %12 = torch.prim.ListConstruct : () -> !torch.list\n" +" %13 = torch.derefine %arg2 : !torch.list to !torch.optional>\n" +" %14 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %13, %9, %10, %11, %false, %12, %int1) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" +" %15:3 = torch.prim.ListUnpack %14 : !torch.list -> !torch.int, !torch.int, !torch.int\n" +" %16 = torch.prim.ListConstruct %15#2, %15#0, %15#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %16 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.conv1d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" +" %false = torch.constant.bool false\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" %1 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %false, %0, %int1) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten._convolution\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.list {\n" " %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" @@ -8121,6 +8886,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.batch_norm(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.optional>, !torch.optional>, !torch.optional>, !torch.optional>, !torch.bool, !torch.float, !torch.float, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.group_norm\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float, %arg5: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.native_group_norm\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.float) -> !torch.tuple, list, list> {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" %1 = torch.prim.ListConstruct %arg3, %arg6 : (!torch.int, !torch.int) -> !torch.list\n" +" %2 = torch.prim.ListConstruct %arg3, %arg6 : (!torch.int, !torch.int) -> !torch.list\n" +" %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" +" return %3 : !torch.tuple, list, list>\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.instance_norm\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" @@ -8348,10 +9128,137 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " } : (!torch.int, !torch.bool) -> ()\n" " return %arg0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.replication_pad2d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: \"\n" +" %int2 = torch.constant.int 2\n" +" %int4 = torch.constant.int 4\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.replication_pad2d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.pad\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.str, %arg3: !torch.optional) -> !torch.list {\n" " %0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.reflection_pad1d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %false = torch.constant.bool false\n" +" %int-1 = torch.constant.int -1\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.aten.lt.int %3, %2 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" +" %8 = torch.aten.lt.int %4, %2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %8 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %7 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.reflection_pad2d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %false = torch.constant.bool false\n" +" %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n" +" %int-1 = torch.constant.int -1\n" +" %int-2 = torch.constant.int -2\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: \"\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %int4 = torch.constant.int 4\n" +" %int0 = torch.constant.int 0\n" +" %int3 = torch.constant.int 3\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %5 = torch.aten.eq.int %4, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %7 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %9 = torch.aten.__getitem__.t %arg1, %int3 : !torch.list, !torch.int -> !torch.int\n" +" %10 = torch.aten.lt.int %6, %3 : !torch.int, !torch.int -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.bool) {\n" +" %15 = torch.aten.lt.int %7, %3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %15 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %12 = torch.aten.lt.int %8, %2 : !torch.int, !torch.int -> !torch.bool\n" +" %13 = torch.prim.If %12 -> (!torch.bool) {\n" +" %15 = torch.aten.lt.int %9, %2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %15 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %13 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %14 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %14 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.index.Tensor\"(%arg0: !torch.list, %arg1: !torch.list>>) -> !torch.list {\n" " %0 = call @__torch__.index_tensor_like(%arg0, %arg1) : (!torch.list, !torch.list>>) -> !torch.list\n" " return %0 : !torch.list\n" @@ -8499,7 +9406,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @__torch__.hacky_get_unknown_dimension_size() -> !torch.int {\n" " %0 = torch.prim.CreateObject !torch.nn.Module<\"__torch__.DummyClassType\">\n" " %1 = torch.prim.CallMethod %0[\"__init__\"] () : !torch.nn.Module<\"__torch__.DummyClassType\">, () -> !torch.none\n" -" %2 = torch.operator \"prim.id\"(%0) : (!torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.int\n" +" %2 = torch.operator \"prim.id\"(%0) : (!torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.int \n" " return %2 : !torch.int\n" " }\n" " func.func @__torch__.DummyClassType.__init__(%arg0: !torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.none {\n" @@ -8527,6 +9434,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg2, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %1 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.linalg_norm\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional>, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.list {\n" +" %0 = torch.derefine %arg4 : !torch.optional to !torch.any\n" +" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg2, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.frobenius_norm.dim\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.list {\n" " %int0 = torch.constant.int 0\n" " %0 = torch.derefine %arg1 : !torch.list to !torch.optional>\n" @@ -8534,6 +9446,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %2 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.norm.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %0 = torch.derefine %none : !torch.none to !torch.optional>\n" +" %1 = torch.derefine %none : !torch.none to !torch.any\n" +" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %false, %1) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" +" return %2 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.norm.ScalarOpt_dim\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.list {\n" " %int0 = torch.constant.int 0\n" " %0 = torch.derefine %arg2 : !torch.list to !torch.optional>\n" @@ -8555,7 +9475,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.tuple) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cosh\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" @@ -8597,6 +9517,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.prim.ListConstruct %int5, %int15, %int6, %int7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.acosh\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.exp\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" @@ -8612,12 +9542,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.cos\"(%arg0: !torch.tuple) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.asin\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.asin\"(%arg0: !torch.tuple) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.asinh\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cos\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" @@ -8662,6 +9597,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logit\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.rsqrt\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" @@ -8809,10 +9749,47 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.avg_pool3d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.batch_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.group_norm\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float, %arg5: !torch.bool) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.native_group_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.float) -> !torch.tuple {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" return %3 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.instance_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -8957,6 +9934,29 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.grid_sampler\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad1d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: padding size expected to be 2\"\n" +" %int2 = torch.constant.int 2\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %2 = torch.aten.eq.int %1, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad2d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.contiguous\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -9144,6 +10144,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.linalg_cross\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %6 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" " return %arg3 : !torch.int\n" " }\n" @@ -9167,12 +10192,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max_pool3d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d_with_indices\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.tuple {\n" " %int4 = torch.constant.int 4\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" " return %1 : !torch.tuple\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.adaptive_max_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mish\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -9390,10 +10425,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.diagonal\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.uniform\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.exponential\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.rand\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %none = torch.constant.none\n" @@ -9569,6 +10612,52 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.isneginf\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int9 = torch.constant.int 9\n" +" %int10 = torch.constant.int 10\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %3 = torch.aten.ne.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %3 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.isposinf\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int9 = torch.constant.int 9\n" +" %int10 = torch.constant.int 10\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %3 = torch.aten.ne.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %3 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %int11 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.ne.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" @@ -10063,24 +11152,118 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %7 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %8 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" %9 = torch.prim.ListConstruct %int11 : (!torch.int) -> !torch.list\n" +" %10 = torch.aten.__contains__.int_list %9, %8 : !torch.list, !torch.int -> !torch.bool\n" +" %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %8 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._convolution\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.int {\n" +" %false = torch.constant.bool false\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" %12 = torch.aten.__isnot__ %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %12 = torch.aten.__isnot__ %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %11 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.conv1d\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.int {\n" +" %false = torch.constant.bool false\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" %12 = torch.aten.__isnot__ %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %12 = torch.aten.__isnot__ %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" " }\n" -" %6 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %7 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %8 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" -" %9 = torch.prim.ListConstruct %int11 : (!torch.int) -> !torch.list\n" -" %10 = torch.aten.__contains__.int_list %9, %8 : !torch.list, !torch.int -> !torch.bool\n" -" %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n" -" torch.prim.If %11 -> () {\n" +" torch.prim.If %8 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" return %8 : !torch.int\n" +" %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %11 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten._convolution\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.conv_tbc\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.int) -> !torch.int {\n" " %false = torch.constant.bool false\n" " %int11 = torch.constant.int 11\n" " %none = torch.constant.none\n" @@ -10178,6 +11361,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.conv3d\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose2d.input\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -10246,6 +11433,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" " return %5 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lerp.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0, %none : (!torch.int, !torch.int, !torch.none) -> !torch.list>\n" +" %3 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n" +" %4 = torch.prim.ListConstruct %0#1, %1#1, %3 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" @@ -10347,6 +11544,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fmod.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" @@ -10463,6 +11668,28 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.selu\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.remainder.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -10472,6 +11699,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.remainder.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" @@ -10550,6 +11785,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.nan_to_num\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.nll_loss_forward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" @@ -10809,6 +12048,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.all.dim\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %int0 = torch.constant.int 0\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.min\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -10952,6 +12203,84 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.linalg_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional>, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.__isnot__ %arg4, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" %5 = torch.prim.unchecked_cast %arg4 : !torch.optional -> !torch.int\n" +" %6 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%5) : (!torch.int) -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" %10 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%5) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %11 = torch.prim.TupleConstruct %0#0, %5 : !torch.int, !torch.int -> !torch.tuple\n" +" %12 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%11, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" torch.prim.If.yield %12 : !torch.int\n" +" } else {\n" +" %10 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%5) : (!torch.int) -> !torch.bool\n" +" %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" torch.prim.If.yield %9 : !torch.int\n" +" } else {\n" +" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %int5 = torch.constant.int 5\n" +" %int8 = torch.constant.int 8\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.eq.int %0#1, %int8 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int5 : !torch.int\n" +" } else {\n" +" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.tensor.float\"(%arg0: !torch.float, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %none = torch.constant.none\n" @@ -11303,6 +12632,32 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.linspace\"(%arg0: !torch.number, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.normal_functional\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.randn.generator\"(%arg0: !torch.list, %arg1: !torch.any, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int6 = torch.constant.int 6\n" @@ -11390,6 +12745,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %4 : !torch.tuple\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tan\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.atan2\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -11416,6 +12782,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.atanh\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.linear\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" @@ -11478,6 +12855,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" " return %5 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.trace\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten._shape_as_tensor\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " return %int4 : !torch.int\n" @@ -11629,6 +13017,83 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_channel\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.int, %arg4: !torch.int) -> !torch.int {\n" +" return %arg4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_tensor\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" return %arg3 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.dequantize.self\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" return %int6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.dequantize.tensor\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" return %int6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.int_repr\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int3 = torch.constant.int 3\n" +" %int1 = torch.constant.int 1\n" +" %int12 = torch.constant.int 12\n" +" %int0 = torch.constant.int 0\n" +" %int13 = torch.constant.int 13\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int13 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" %3 = torch.aten.eq.int %0#1, %int12 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int3 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._make_per_channel_quantized_tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.int) -> !torch.int {\n" +" %int14 = torch.constant.int 14\n" +" %int12 = torch.constant.int 12\n" +" %int1 = torch.constant.int 1\n" +" %int13 = torch.constant.int 13\n" +" %int0 = torch.constant.int 0\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int13 : !torch.int\n" +" } else {\n" +" %3 = torch.aten.eq.int %0#1, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int12 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int14 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._make_per_tensor_quantized_tensor\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.int) -> !torch.int {\n" +" %int14 = torch.constant.int 14\n" +" %int12 = torch.constant.int 12\n" +" %int1 = torch.constant.int 1\n" +" %int13 = torch.constant.int 13\n" +" %int0 = torch.constant.int 0\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int13 : !torch.int\n" +" } else {\n" +" %3 = torch.aten.eq.int %0#1, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int12 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int14 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" "}\n" ""; // clang-format on diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp index 30cc4db44181..2891a22eb817 100644 --- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp +++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp @@ -81,7 +81,7 @@ class AdjustCallingConventionForFunc } newResultTypes.push_back(type); } - rewriter.updateRootInPlace(func, [&] { + rewriter.modifyOpInPlace(func, [&] { func.setType(FunctionType::get( getContext(), conversion.getConvertedTypes(), newResultTypes)); // Clear out the type bounds, now that the type incorporates them. @@ -194,14 +194,12 @@ static LogicalResult adjustCallingConventions(func::FuncOp func, TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); typeConverter.addConversion( - [](Torch::TupleType type, - SmallVectorImpl &types) -> LogicalResult { + [](Torch::TupleType type, SmallVectorImpl &types) -> LogicalResult { llvm::append_range(types, type.getContainedTypes()); return success(); }); typeConverter.addConversion( - [](Torch::NoneType type, - SmallVectorImpl &types) -> LogicalResult { + [](Torch::NoneType type, SmallVectorImpl &types) -> LogicalResult { return success(); }); diff --git a/lib/Dialect/Torch/Transforms/CMakeLists.txt b/lib/Dialect/Torch/Transforms/CMakeLists.txt index 0f7621ff0da4..4def554d9f49 100644 --- a/lib/Dialect/Torch/Transforms/CMakeLists.txt +++ b/lib/Dialect/Torch/Transforms/CMakeLists.txt @@ -3,10 +3,12 @@ add_mlir_library(TorchMLIRTorchPasses DecomposeComplexOps.cpp DropAbstractInterpCalculations.cpp EraseModuleInitializer.cpp + FuseQuantizedOps.cpp Passes.cpp GlobalizeObjectGraph.cpp InlineGlobalSlots.cpp LowerToBackendContract.cpp + MatchQuantizedOps.cpp MaximizeValueSemantics.cpp PrepareForGlobalizeObjectGraph.cpp RecomposeComplexOps.cpp diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 281a827858b7..39d198c1dac7 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -71,8 +71,8 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op, } Type resultType = tensorType.getWithSizesAndDtype( - sizes.size() == 0 ? std::optional>() - : llvm::ArrayRef(sizes), + !tensorType.hasSizes() ? std::optional>() + : llvm::ArrayRef(sizes), tensorType.getOptionalDtype()); return resultType; } @@ -126,35 +126,6 @@ static Value createTensorSub(PatternRewriter &rewriter, Location loc, return sub; } -// Helper to create a tensor filled with the given scalar. Scalar would be -// converted the to the element type of the given tensor type. -static Value createInitTensor(PatternRewriter &rewriter, Location loc, - BaseTensorType resultType, Value scalar, - Value sizeList) { - assert(resultType.hasDtype() && "result must have dtype"); - Value noneVal = rewriter.create(loc); - Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype()); - return rewriter.create(loc, resultType, sizeList, scalar, dtype, - /*layout=*/noneVal, - /*device=*/noneVal, - /*memory_format=*/noneVal); -} - -// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar` -// would be converted to the element type of the given `inputType`. -static Value createRank0Tensor(PatternRewriter &rewriter, Location loc, - BaseTensorType inputType, Value scalar) { - assert(inputType.hasDtype() && "input must have dtype"); - SmallVector sizes; - BaseTensorType rank0TensorTy = - inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype()) - .cast(); - Value dimList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())), - ValueRange{}); - return createInitTensor(rewriter, loc, rank0TensorTy, scalar, dimList); -} - // Share code between `softmax_backward` and `log_softmax_backward` ops. // Returns x - y * sum(z, dim). static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter, @@ -210,7 +181,7 @@ static bool parseEquation(const std::string &equation, inputToken.clear(); currentVariable = kIsResult; index++; - } else { + } else if (equation[index] != ' ') { return false; } index++; @@ -384,7 +355,7 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, auto rhsType = rhs.getType().cast(); Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype() - : rhsType.getOptionalDtype(); + : rhsType.getOptionalDtype(); llvm::SmallDenseMap lhsDimShapeMap; for (size_t idx = 0; idx < lhsTokens.size(); ++idx) { @@ -486,7 +457,6 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, return success(); } - static Value performLastReduceAndPermute(PatternRewriter &rewriter, Location loc, Type outType, Value input, @@ -723,6 +693,131 @@ class DecomposeAtenSelectIntOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposePrimTolistOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(PrimTolistOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto self = op.getOperands()[0]; + auto selfTy = dyn_cast(self.getType()); + if (!selfTy || !selfTy.hasSizes()) + return rewriter.notifyMatchFailure(op, "Unknown self shape"); + + int64_t rank = selfTy.getSizes().size(); + if (rank != 1) + return rewriter.notifyMatchFailure(op, "Expected rank-1"); + + int64_t length = selfTy.getSizes().back(); + if (length == Torch::kUnknownSize) + return rewriter.notifyMatchFailure(op, "Tolist length is unknown"); + + auto resultTy = dyn_cast(op.getType(0)); + if (!resultTy) + return rewriter.notifyMatchFailure(op, "Result type is not list"); + + auto scalarTy = resultTy.getContainedType(); + Value zero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + auto extractTy = rewriter.getType( + llvm::SmallVector{1}, selfTy.getOptionalDtype()); + llvm::SmallVector results; + llvm::SmallVector sizes(selfTy.getSizes()); + for (int64_t i = 0; i < length; ++i) { + Value iv = + rewriter.create(loc, rewriter.getI64IntegerAttr(i)); + Value extract = rewriter.create( + loc, extractTy, self, /*dim=*/zero, /*index=*/iv); + Value scalar = rewriter.create(loc, scalarTy, extract); + results.push_back(scalar); + } + + rewriter.replaceOpWithNewOp(op, resultTy, results); + return failure(); + } +}; +} // namespace + +namespace { +class DecomposeAtenSplitSizesOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenSplitSizesOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op.getSelf(), op.getSplitSize(), op.getDim()); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenSplitWithSizesOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenSplitWithSizesOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value self = op.getSelf(); + SmallVector splitSizes; + if (!getListConstructElements(op.getSplitSizes(), splitSizes)) + return rewriter.notifyMatchFailure(op, "Unable to get sizes"); + + if (splitSizes.empty()) + return rewriter.notifyMatchFailure(op, "No split sizes"); + + auto selfTy = dyn_cast(self.getType()); + if (!selfTy || !selfTy.hasSizes()) + return rewriter.notifyMatchFailure(op, "Self shape unknown"); + + int64_t rank = selfTy.getSizes().size(); + auto resultTy = dyn_cast(op.getResult().getType()); + if (!resultTy) + return rewriter.notifyMatchFailure(op, "Result type not a list"); + + auto sliceTy = + dyn_cast_or_null(resultTy.getContainedType()); + if (!isa(sliceTy)) + return rewriter.notifyMatchFailure(op, "Slice type is unknown"); + + int64_t dimInt = 0; + bool hasDim = matchPattern(op.getDim(), m_TorchConstantInt(&dimInt)); + if (dimInt < 0) + dimInt += rank; + + auto intTy = rewriter.getType(); + Value one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value begin = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + + llvm::SmallVector slices; + llvm::SmallVector sliceSizes(sliceTy.getSizes()); + int64_t defaultLength = !hasDim ? Torch::kUnknownSize : sliceSizes[dimInt]; + for (auto size : splitSizes) { + Value end = rewriter.create(loc, intTy, begin, size); + + int64_t sizeInt; + if (hasDim && matchPattern(size, m_TorchConstantInt(&sizeInt))) { + sliceSizes[dimInt] = sizeInt; + } else if (hasDim) { + sliceSizes[dimInt] = defaultLength; + } + + sliceTy = rewriter.getType(sliceSizes, + sliceTy.getOptionalDtype()); + Value slice = rewriter.create( + loc, sliceTy, op.getSelf(), + /*dim=*/op.getDim(), /*start=*/begin, /*end=*/end, /*step=*/one); + slices.push_back(slice); + begin = end; + } + + rewriter.replaceOpWithNewOp(op, resultTy, slices); + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenNarrowOp : public OpRewritePattern { public: @@ -961,6 +1056,40 @@ class DecomposeAtenIsinfOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenIsneginfOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenIsneginfOp op, + PatternRewriter &rewriter) const override { + mlir::FloatType f64Type = rewriter.getF64Type(); + Value inf = rewriter.create( + op.getLoc(), + rewriter.getFloatAttr( + f64Type, APFloat::getInf(f64Type.getFloatSemantics(), true))); + rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), + inf); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenIsposinfOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenIsposinfOp op, + PatternRewriter &rewriter) const override { + mlir::FloatType f64Type = rewriter.getF64Type(); + Value inf = rewriter.create( + op.getLoc(), + rewriter.getFloatAttr(f64Type, + APFloat::getInf(f64Type.getFloatSemantics()))); + rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), + inf); + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenReshapeOp : public OpRewritePattern { public: @@ -1052,6 +1181,54 @@ class DecomposeAtenEinsumOp : public OpRewritePattern { }; } // namespace +namespace { +// Calculate the trace of the input tensor as the sum over its diagonal +// elements. This computation is performed as: +// +// Step1: Obtain the diagonal using AtenDiagonalOp +// Step2: Compute the trace using AtenSumOp. +// +// It is verified that the input tensor has rank two. +class DecomposeAtenTraceOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenTraceOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + std::optional inRank = getTensorRank(self); + if (inRank != 2) + return rewriter.notifyMatchFailure( + op, "Expected input tensor to have rank 2."); + + Value none = rewriter.create(loc); + Value zero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + BaseTensorType inputType = self.getType().cast(); + + Value output = op.getResult(); + BaseTensorType outputType = output.getType().cast(); + + ArrayRef inputShape = inputType.getSizes(); + int64_t diagonalSize = std::min(inputShape[0], inputShape[1]); + SmallVector diagonalShape{diagonalSize}; + Type elementType = inputType.getOptionalDtype(); + Type diagonalType = inputType.getWithSizesAndDtype( + llvm::ArrayRef(diagonalShape), elementType); + + Value diagonal = rewriter.create( + loc, diagonalType, /*input=*/self, /*offset=*/zero, /*dim1=*/zero, + /*dim2=*/one); + Value sum = rewriter.create(loc, outputType, /*self=*/diagonal, + /*dtype=*/none); + rewriter.replaceOp(op, sum); + return success(); + } +}; +} // namespace + // Calculates the softmax function on the given `input` tensor. Softmax(x) = // exp(x)/sum(exp(x)). // To avoid overflow we use the following decomposition rule: @@ -1060,22 +1237,32 @@ class DecomposeAtenEinsumOp : public OpRewritePattern { // softmax = unnorm / sum(unnorm, dim, keepdim = True) template static Value getSoftmaxResult(OpTy op, Value self, Type resultType, - PatternRewriter &rewriter) { + Type accumulatorType, PatternRewriter &rewriter) { Location loc = op.getLoc(); Value dim = op.getDim(); + if (resultType != accumulatorType) + self = convertTensorToDtype(rewriter, loc, self, accumulatorType); Value xMax = createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true); + if (!xMax) return nullptr; - Value unNormalized = createTensorSub(rewriter, loc, resultType, self, xMax); + Value unNormalized = + createTensorSub(rewriter, loc, self.getType(), self, xMax); Value unNormalizedExp = - rewriter.create(loc, resultType, unNormalized); + rewriter.create(loc, self.getType(), unNormalized); Value sum = createSumAlongDimension(rewriter, loc, op, unNormalizedExp, dim, /*keepDim=*/true); if (!sum) return nullptr; - return rewriter.create(loc, resultType, unNormalizedExp, - sum); + + Value result = rewriter.create(loc, self.getType(), + unNormalizedExp, sum); + if (resultType != accumulatorType) + result = convertTensorToDtype(rewriter, loc, result, + resultType.cast().getDtype()); + + return result; } // Decompose softmax into: exp(x) / sum(exp(x)) @@ -1107,7 +1294,10 @@ class DecomposeAtenSoftmaxIntOp : public OpRewritePattern { /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); } - Value result = getSoftmaxResult(op, self, resultTensorType, rewriter); + Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype); + + Value result = getSoftmaxResult(op, self, resultTensorType, + accumulatorTensorType, rewriter); if (!result) return failure(); rewriter.replaceOpWithNewOp(op, op.getType(), @@ -1152,7 +1342,11 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern { getDtypeIntValueForType(rewriter, loc, resultTensorDtype), /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); } - Value result = getSoftmaxResult(op, self, resultTensorType, rewriter); + + Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype); + + Value result = getSoftmaxResult(op, self, resultTensorType, + accumulatorTensorType, rewriter); if (!result) return op.emitError("failed to get softmax result"); rewriter.replaceOpWithNewOp(op, resultTensorType, @@ -1264,7 +1458,59 @@ class DecomposeAten_LogSoftmaxBackwardDataOp }; } // namespace -// Decompose `AtenArgMaxOp` into `AtenMaxDimOp` as well as `AtenArgMinOp` into `AtenMinDimOp` +namespace { +class DecomposeAtenAMinMaxOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Torch::AtenAminOp op, + PatternRewriter &rewriter) const override { + llvm::SmallVector dimList; + if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) { + return rewriter.notifyMatchFailure(op, "dims not foldable constants"); + } + + bool keepdim; + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepdim))) { + return rewriter.notifyMatchFailure(op, "keepdims not foldable constants"); + } + + auto loc = op.getLoc(); + std::sort(dimList.begin(), dimList.end(), std::greater()); + + Value reduction = op.getSelf(); + auto resultTy = cast(op.getType()); + auto reductionTy = cast(reduction.getType()); + llvm::SmallVector reductionShape(reductionTy.getSizes()); + + for (auto dim : dimList) { + auto dimValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(dim)); + reductionShape[dim] = 1; + if (!keepdim) { + for (int i = dim, s = reductionShape.size() - 1; i < s; ++i) + reductionShape[i] = reductionShape[i + 1]; + reductionShape.resize(reductionShape.size() - 1); + } + + reductionTy = rewriter.getType( + reductionShape, resultTy.getOptionalDtype()); + auto idxTy = rewriter.getType( + reductionShape, rewriter.getIntegerType(32, /*is_signed*/ true)); + llvm::SmallVector types{reductionTy, idxTy}; + reduction = rewriter + .create(loc, types, reduction, + dimValue, op.getKeepdim()) + .getResult(0); + } + + rewriter.replaceOp(op, reduction); + return success(); + } +}; +} // namespace + +// Decompose `AtenArgMaxOp` into `AtenMaxDimOp` as well as `AtenArgMinOp` into +// `AtenMinDimOp` namespace { template class DecomposeAtenArgMinMaxOp : public OpRewritePattern { @@ -1295,9 +1541,9 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern { .cast(); // If the dim type is `NoneType` i.e. reduce along all the dimensions. - // `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so first the input - // tensor is flattened to 1d tensor and then the reduction happens on the - // 0th dimension. + // `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so + // first the input tensor is flattened to 1d tensor and then the reduction + // happens on the 0th dimension. if (dim.getType().isa()) { BaseTensorType flattenType = inputType @@ -1312,11 +1558,11 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern { } Value resultArg = - rewriter - .create(loc, valueTensorType, indicesTensorType, - input, dim, keepDim) - .getIndices(); - + rewriter + .create(loc, valueTensorType, indicesTensorType, input, + dim, keepDim) + .getIndices(); + rewriter.replaceOp(op, resultArg); return success(); } @@ -1577,6 +1823,117 @@ class DecomposeAtenMvOp : public OpRewritePattern { }; } // namespace +// Decompose aten.linalg_cross into: aten.broadcast_to, aten.index_select, +// aten.add.Tensor and aten.mull.Tensor. See +// https://github.com/pytorch/pytorch/blob/ed3c256b61f05720843454a9282aa7c903da2c81/torch/_refs/linalg/__init__.py#L70. +// def linalg_cross(self: Tensor, other: Tensor, dim: int = -1): +// broadcast_shape = compute_broadcast_shape(self, other) +// a = torch.broadcast_to(self, broadcast_shape) +// b = torch.broadcast_to(other, broadcast_shape) +// idx = torch.arange(3) +// return a.index_select(dim, (idx + 1) % 3) * +// b.index_select(dim, (idx + 2) % 3) - +// a.index_select(dim, (idx + 2) % 3) * +// b.index_select(dim, (idx + 1) % 3) +namespace { +class DecomposeAtenLinalgCrossOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenLinalgCrossOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value other = op.getOther(); + Type opType = op.getType(); + Value dim = op.getDim(); + + auto resType = self.getType().cast(); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + + Type dtype = resType.getDtype(); + if (dtype.isa()) { + return rewriter.notifyMatchFailure( + op, "lowering of aten.linalg_cross for complex inputs dtype is " + "currently unimplemented"); + } + + // calculate common shape for broadcast + SmallVector broadcastShape; + SmallVector broadcastShapeValue; + computeBroadcastShape(rewriter, loc, self, other, broadcastShape, + broadcastShapeValue); + + Type broadcastType = ValueTensorType::get( + op.getContext(), llvm::ArrayRef(broadcastShape), dtype); + + Value indexBroadcastShapeTorchList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), + broadcastShapeValue); + + // broadcast tensors to common shape + auto a = rewriter.create(loc, broadcastType, self, + indexBroadcastShapeTorchList); + auto b = rewriter.create(loc, broadcastType, other, + indexBroadcastShapeTorchList); + + // create constants + Value constOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + Value constTwo = rewriter.create( + loc, rewriter.getI64IntegerAttr(2)); + Value constThree = rewriter.create( + loc, rewriter.getI64IntegerAttr(3)); + Value none = rewriter.create(loc); + + // idx = torch.arange(3) + auto outType = opType.dyn_cast(); + auto arangeType = outType.getWithSizesAndDtype( + llvm::ArrayRef(3), + IntegerType::get(op.getContext(), 64, IntegerType::Signed)); + auto idx = rewriter.create( + loc, arangeType, constThree, /*dtype=*/none, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none); + + // (idx + 1) and (idx + 2) + auto idxPlusOne = rewriter.create(loc, arangeType, idx, + constOne, constOne); + auto idxPlusTwo = rewriter.create(loc, arangeType, idx, + constTwo, constOne); + + // (idx + 1) % 3 and (idx + 2) % 3 + auto idxPlusOneRemainderThree = rewriter.create( + loc, arangeType, idxPlusOne, constThree); + auto idxPlusTwoRemainderThree = rewriter.create( + loc, arangeType, idxPlusTwo, constThree); + + // a.index_select(dim, (idx + 1) % 3) * b.index_select(dim, (idx + 2) % 3) + auto idxSelectAPlusOne = rewriter.create( + loc, opType, a, dim, idxPlusOneRemainderThree); + auto idxSelectBPlusTwo = rewriter.create( + loc, opType, b, dim, idxPlusTwoRemainderThree); + auto firstMul = rewriter.create( + loc, opType, idxSelectAPlusOne, idxSelectBPlusTwo); + + // a.index_select(dim, (idx + 2) % 3) * b.index_select(dim, (idx + 1) % 3) + auto idxSelectAPlusTwo = rewriter.create( + loc, opType, a, dim, idxPlusTwoRemainderThree); + auto idxSelectBPlusOne = rewriter.create( + loc, opType, b, dim, idxPlusOneRemainderThree); + auto secondMul = rewriter.create( + loc, opType, idxSelectAPlusTwo, idxSelectBPlusOne); + + // subtract the results of the two multiplications from above + rewriter.replaceOpWithNewOp(op, opType, firstMul, + secondMul, constOne); + + return success(); + } +}; +} // namespace + // Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and // prims.collapse operations. // @@ -1890,6 +2247,35 @@ class DecomposeAtenLeakyReluBackwardOp }; } // namespace +namespace { +class DecomposeAtenLerpScalarOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenLerpScalarOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto resType = op.getType().cast(); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + Value cstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + auto start = op.getSelf(); + auto inputType = start.getType().cast(); + + auto delta = rewriter.create(loc, inputType, op.getEnd(), + start, cstOne); + + auto weightedDelta = + rewriter.create(loc, inputType, delta, op.getWeight()); + auto lerp = rewriter.create(loc, inputType, start, + weightedDelta, cstOne); + rewriter.replaceOp(op, lerp); + return success(); + } +}; +} // namespace + // Elu = scale * max(0,x) + alpha * scale * (exp(min(0,x) * input_scale) - 1) namespace { class DecomposeAtenEluOp : public OpRewritePattern { @@ -1937,6 +2323,61 @@ class DecomposeAtenEluOp : public OpRewritePattern { }; } // namespace +// Selu = scale * (max(0,x) + min(0,alpha * (exp(x) − 1))) +namespace { +class DecomposeAtenSeluOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenSeluOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = op.getSelf(); + auto resType = op.getType().cast(); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + + // Define λ and α + double scale = 1.0507009873554804934193349852946; + double alpha = 1.6732632423543772848170429916717; + + // Create constants for λ and α + Value scaleVal = rewriter.create( + loc, rewriter.getF64FloatAttr(scale)); + Value alphaVal = rewriter.create( + loc, rewriter.getF64FloatAttr(alpha)); + + // Create zero tensor for comparison + Value constantZero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero); + + // Calculate positive and negative parts + Value constantOne = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + Value positiveOutput = + rewriter.create(loc, resType, zeroTensor, input); + Value minZeroX = + rewriter.create(loc, resType, zeroTensor, input); + Value expInput = rewriter.create(loc, resType, minZeroX); + Value expInputMinusOne = rewriter.create( + loc, resType, expInput, constantOne, constantOne); + Value negativeOutput = rewriter.create( + loc, resType, expInputMinusOne, alphaVal); + + // Multiply the result by λ + Value seluOutput = rewriter.create( + loc, resType, positiveOutput, negativeOutput, constantOne); + seluOutput = + rewriter.create(loc, resType, seluOutput, scaleVal); + + // Replace the original operation + rewriter.replaceOp(op, seluOutput); + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenTOp : public OpRewritePattern { public: @@ -2085,31 +2526,9 @@ class DecomposeAtenRollOp : public OpRewritePattern { }; } // namespace -// Decompose aten.repeat into aten.expand and aten.view ops. +// Decompose aten.repeat into aten.squeeze, aten.unsqueeze, and aten.broadcast. // // Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html -// -// For shape [S1, S2, S3] and repeats [M0, M1, M2, M3] -// MS0 = M0; MS1 = M1 * S1; MS2 = M2 * S2; MS3 = M3 * S3 -// -// def aten_repeat(self, repeats): -// sizes = self.size() -// unsqueezed_sizes = [] -// expanded_sizes = [] -// reshape_sizes = [] -// leading_rank = repeats.size() - sizes.size() -// for r in range(leading_rank): -// unsqueezed_sizes.append(1) -// expanded_sizes.append(repeats[r]) -// reshaped_sizes.append(repeats[r]) -// -// for s, m in zip(sizes, repeats[leading_rank:]): -// unsqueezed_sizes += [1, s] -// expanded_sizes += [m, s] -// reshaped_sizes += [m * s] -// return -// self.view(unsqueezed_sizes).expand(expanded_sizes).view(reshaped_sizes) -// namespace { class DecomposeAtenRepeatOp : public OpRewritePattern { public: @@ -2118,94 +2537,110 @@ class DecomposeAtenRepeatOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value self = op.getSelf(); - MLIRContext *context = op.getContext(); - std::optional maybeRank = getTensorRank(self); - if (!maybeRank) - return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor"); - unsigned rank = *maybeRank; + auto selfTy = cast(self.getType()); + if (!selfTy.hasSizes()) + return rewriter.notifyMatchFailure( + op, "Unimplemented: no implementation for rankless tensor"); SmallVector repeats; if (!getListConstructElements(op.getRepeats(), repeats)) return rewriter.notifyMatchFailure( op, "Unimplemented: repeats not list of Scalar"); - if (rank > repeats.size()) { + int64_t rank = selfTy.getSizes().size(); + if (rank > static_cast(repeats.size())) { return rewriter.notifyMatchFailure( op, "repeats are not matched with self's rank"); } - auto insertDimSizes = [](SmallVector &dimSizes, - SmallVector &shape, - const ArrayRef &vals) { - dimSizes.insert(dimSizes.end(), vals.begin(), vals.end()); - std::transform(vals.begin(), vals.end(), std::back_inserter(shape), - [&](Value val) -> int64_t { - int64_t cst_val; - if (matchPattern(val, m_TorchConstantInt(&cst_val))) { - return cst_val; - } else { - return kUnknownSize; - } - }); - }; + int64_t repeatSz = repeats.size(); + int64_t batch = repeatSz - rank; - Value one = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); + if (!selfTy.hasSizes()) + return rewriter.notifyMatchFailure(op, "input sizes unknown"); + + // Materialize out 1 dimensions to broadcast along. This includes + // materializing out preceding batch dimensions: + for (int i = 0; i < repeatSz; ++i) { + auto oldSizes = selfTy.getSizes(); + llvm::SmallVector sizes; + int64_t squeezeDim = i < batch ? i : i * 2 - batch; + + for (int j = 0; j < squeezeDim; ++j) + sizes.push_back(oldSizes[j]); + sizes.push_back(1); + for (int j = squeezeDim, s = oldSizes.size(); j < s; j++) + sizes.push_back(oldSizes[j]); - SmallVector unsqueezedSizes, expandedSizes, reshapedSizes; - SmallVector unsqueezedIntSizes, expandedIntSizes; - assert(repeats.size() >= rank && "leadingRank should greater than 0"); - auto leadingRank = repeats.size() - rank; - for (size_t i = 0; i < leadingRank; ++i) { - insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef{one}); - insertDimSizes(expandedSizes, expandedIntSizes, - ArrayRef{repeats[i]}); - reshapedSizes.push_back(repeats[i]); + Value dim = rewriter.create(loc, squeezeDim); + selfTy = + rewriter.getType(sizes, selfTy.getOptionalDtype()); + self = rewriter.create(loc, selfTy, self, dim); } - auto selfType = self.getType().dyn_cast(); - auto selfShape = selfType.getSizes(); - for (unsigned i = 0; i < rank; i++) { - auto scale = repeats[i + leadingRank]; - Value dimSize; - if (selfShape[i] == kUnknownSize) { - Value dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); - dimSize = rewriter.create(loc, self, dim); - } else { - dimSize = rewriter.create( - loc, rewriter.getI64IntegerAttr(selfShape[i])); + llvm::SmallVector lengths; + for (int i = 0; i < repeatSz; ++i) { + if (i < batch) { + lengths.push_back(repeats[i]); + continue; } - insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, - ArrayRef{one, dimSize}); - insertDimSizes(expandedSizes, expandedIntSizes, - ArrayRef{scale, dimSize}); + Value iv = rewriter.create( + loc, rewriter.getI64IntegerAttr(i * 2 + 1 - batch)); + Value dim = rewriter.create(loc, self, /*dim=*/iv); + lengths.push_back(repeats[i]); + lengths.push_back(dim); + } + + Value lengthv = rewriter.create( + loc, ListType::get(rewriter.getType()), lengths); - Value scaledSize = rewriter.create(loc, dimSize, scale); - reshapedSizes.push_back(scaledSize); + llvm::SmallVector expandShape(selfTy.getSizes()); + for (int i = 0; i < repeatSz; ++i) { + int64_t repeatDim = i < batch ? i : i * 2 - batch; + int64_t repeat; + if (!matchPattern(repeats[i], m_TorchConstantInt(&repeat))) + repeat = Torch::kUnknownSize; + expandShape[repeatDim] = repeat; } - Type dtype = self.getType().cast().getOptionalDtype(); - Type unsqueezedType = ValueTensorType::get( - context, llvm::ArrayRef(unsqueezedIntSizes), dtype); - Type expandedType = - ValueTensorType::get(context, llvm::ArrayRef(expandedIntSizes), dtype); + auto mulDim = [](int64_t lhs, int64_t rhs) { + if (lhs == Torch::kUnknownSize || rhs == Torch::kUnknownSize) + return Torch::kUnknownSize; + return lhs * rhs; + }; - auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext())); - Value unsqueezedDims = - rewriter.create(loc, listType, unsqueezedSizes); - Value expandedDims = - rewriter.create(loc, listType, expandedSizes); - Value reshapedDims = - rewriter.create(loc, listType, reshapedSizes); - auto reshaped = rewriter.create(loc, unsqueezedType, - op.getSelf(), unsqueezedDims); - auto expanded = rewriter.create(loc, expandedType, - reshaped, expandedDims); + BaseTensorType expandTy = rewriter.getType( + expandShape, selfTy.getOptionalDtype()); + Value expand = + rewriter.create(loc, expandTy, self, lengthv); + + for (int i = 0; i < rank; ++i) { + auto oldShape = expandTy.getSizes(); + llvm::SmallVector newShape; + int64_t flattenDim = i + batch; + for (int j = 0; j < flattenDim; ++j) + newShape.push_back(oldShape[j]); + newShape.push_back( + mulDim(oldShape[flattenDim], oldShape[flattenDim + 1])); + for (int j = flattenDim + 2, s = oldShape.size(); j < s; ++j) + newShape.push_back(oldShape[j]); + + expandTy = rewriter.getType(newShape, + expandTy.getOptionalDtype()); + + // Used to keep the return type the same on the last flatten: + expandTy = i < rank - 1 ? expandTy : cast(op.getType()); + + Value start = rewriter.create( + loc, rewriter.getI64IntegerAttr(flattenDim)); + Value end = rewriter.create( + loc, rewriter.getI64IntegerAttr(flattenDim + 1)); + expand = rewriter.create(loc, expandTy, expand, + start, end); + } - rewriter.replaceOpWithNewOp(op, op.getType(), expanded, - reshapedDims); + rewriter.replaceOp(op, expand); return success(); } }; @@ -2451,6 +2886,49 @@ class DecomposeAtenWhereScalarSelfOp }; } // namespace +namespace { +class DecomposeAtenNanToNumOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNanToNumOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + mlir::FloatType f64Type = rewriter.getF64Type(); + Value nan = op.getNan(); + Value posinf = op.getPosinf(); + Value neginf = op.getNeginf(); + auto baseType = + ValueTensorType::getWithLeastStaticInformation(op.getContext()); + if (dyn_cast_or_null(nan.getDefiningOp())) + nan = rewriter.create( + loc, rewriter.getFloatAttr( + f64Type, APFloat::getZero(f64Type.getFloatSemantics()))); + if (dyn_cast_or_null(posinf.getDefiningOp())) + posinf = rewriter.create( + loc, rewriter.getFloatAttr( + f64Type, APFloat::getInf(f64Type.getFloatSemantics()))); + if (dyn_cast_or_null(neginf.getDefiningOp())) + neginf = rewriter.create( + loc, + rewriter.getFloatAttr( + f64Type, APFloat::getInf(f64Type.getFloatSemantics(), true))); + Value isNan = + rewriter.create(loc, baseType, op.getSelf()); + Value where = rewriter.create( + loc, baseType, isNan, nan, op.getSelf()); + Value isposinf = + rewriter.create(loc, baseType, where); + where = rewriter.create( + loc, baseType, isposinf, posinf, where); + Value isneginf = + rewriter.create(loc, baseType, where); + rewriter.replaceOpWithNewOp( + op, op.getType(), isneginf, neginf, where); + return success(); + } +}; +} // namespace + // Decompose aten.masked_fill.Scalar into aten.where.self op. namespace { class DecomposeAtenMaskedFillScalarOp @@ -2473,32 +2951,162 @@ class DecomposeAtenMaskedFillScalarOp }; } // namespace -// Decompose aten._convolution-like to aten.convolution +// Decompose aten._convolution-like to aten.convolution +namespace { +template +class DecomposeAten_ConvolutionLikeOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ConvolutionLikeOp op, + PatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), + op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(), + op.getOutputPadding(), op.getGroups()); + + return success(); + } +}; +} // namespace + +namespace { + +static LogicalResult createTorchTransposeOpForConvTbc(PatternRewriter &rewriter, + Location loc, Value input, + int64_t dimA, + int64_t dimB, + Value &transposed) { + Type transposedType; + if (failed(getTransposedType(input.getType().cast(), + dimA, dimB, transposedType))) + return failure(); + Value cstDimA = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimA)); + Value cstDimB = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimB)); + transposed = rewriter.create( + loc, transposedType, input, cstDimA, cstDimB); + return success(); +} + +class DecomposeAtenConvTbcOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenConvTbcOp op, + PatternRewriter &rewriter) const override { + Value emptyList = rewriter.create( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + SmallVector()); + Value cstFalse = rewriter.create(op.getLoc(), false); + Value oneList = rewriter.create( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + SmallVector{rewriter.create( + op.getLoc(), rewriter.getI64IntegerAttr(1))}); + Value padding = rewriter.create( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + SmallVector{op.getPad()}); + Value groups = rewriter.create( + op.getLoc(), rewriter.getI64IntegerAttr(1)); + + // convtbc has WNC layout for input and output + // and WCF layout for weight + // whereas Convolution is going to use Conv1DNcwFcwOp for 1d + // which means we need the inputs in NCW and the weight in FCW + Value selfWnc = op.getSelf(); + Value selfNwc; + Value selfNcw; + if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfWnc, + 0, 1, selfNwc))) + return rewriter.notifyMatchFailure(op, + "failed to transpose input to Nwc"); + if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfNwc, + 1, 2, selfNcw))) + return rewriter.notifyMatchFailure(op, + "failed to transpose input to Ncw"); + + Value weightWcf = op.getWeight(); + Value weightFcw; + if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), + weightWcf, 0, 2, weightFcw))) + return rewriter.notifyMatchFailure(op, + "failed to transpose weight to Fcw"); + + Value outputNcw = rewriter.create( + op.getLoc(), op->getResultTypes(), selfNcw, weightFcw, op.getBias(), + /*stride*/ oneList, + /*padding*/ padding, /*dilation*/ oneList, + /*transpose*/ cstFalse, /*output_padding*/ emptyList, groups); + + // convert output from Ncw to Wnc + Value outputNwc; + Value outputWnc; + if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), + outputNcw, 1, 2, outputNwc))) + return rewriter.notifyMatchFailure(op, + "failed to transpose output to Nwc"); + if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), + outputNwc, 0, 1, outputWnc))) + return rewriter.notifyMatchFailure(op, + "failed to transpose output to Wnc"); + rewriter.replaceOp(op, outputWnc); + + return success(); + } +}; +} // namespace + +// Decompose aten.conv1d to aten.convolution +namespace { +class DecomposeAtenConv1dOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenConv1dOp op, + PatternRewriter &rewriter) const override { + + Value emptyList = rewriter.create( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + SmallVector()); + Value cstFalse = rewriter.create(op.getLoc(), false); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), + op.getStride(), op.getPadding(), op.getDilation(), cstFalse, emptyList, + op.getGroups()); + + return success(); + } +}; +} // namespace + +// Decompose aten.conv2d to aten.convolution namespace { -template -class DecomposeAten_ConvolutionLikeOp - : public OpRewritePattern { +class DecomposeAtenConv2dOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ConvolutionLikeOp op, + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenConv2dOp op, PatternRewriter &rewriter) const override { + Value emptyList = rewriter.create( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + SmallVector()); + Value cstFalse = rewriter.create(op.getLoc(), false); rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), - op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(), - op.getOutputPadding(), op.getGroups()); + op.getStride(), op.getPadding(), op.getDilation(), cstFalse, emptyList, + op.getGroups()); return success(); } }; } // namespace -// Decompose aten.conv2d to aten.convolution +// Decompose aten.conv3d to aten.convolution namespace { -class DecomposeAtenConv2dOp : public OpRewritePattern { +class DecomposeAtenConv3dOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenConv2dOp op, + LogicalResult matchAndRewrite(AtenConv3dOp op, PatternRewriter &rewriter) const override { Value emptyList = rewriter.create( @@ -2534,19 +3142,6 @@ class DecomposeAtenConvTranspose2dOp }; } // namespace -static LogicalResult getTransposedType(BaseTensorType inType, int64_t dimA, - int64_t dimB, Type &transposedType) { - if (!inType.hasSizes()) - return failure(); - SmallVector shape(inType.getSizes()); - int64_t tmp = shape[0]; - shape[0] = shape[1]; - shape[1] = tmp; - transposedType = inType.getWithSizesAndDtype(llvm::ArrayRef(shape), - inType.getOptionalDtype()); - return success(); -} - // The convolution backward op is decomposed as follows: // inputH, inputW = input.shape[2:] // output_padding_ = [ @@ -3471,7 +4066,7 @@ class DecomposeAtenBernoulliOp : public OpRewritePattern { Value input = op.getSelf(); if (!op.getGenerator().getType().isa()) return rewriter.notifyMatchFailure( - op, "The generator has to ben None because only global default " + op, "The generator has to be None because only global default " "generator is supported"); Value output; if (failed( @@ -3497,7 +4092,7 @@ class DecomposeAtenBernoulliLikeOp : public OpRewritePattern { Value p = op.getP(); if (!op.getGenerator().getType().template isa()) return rewriter.notifyMatchFailure( - op, "The generator has to ben None because only global default " + op, "The generator has to be None because only global default " "generator is supported"); auto inputType = input.getType().cast(); @@ -3529,7 +4124,7 @@ class DecomposeAtenBernoulliTensorOp Value prob = op.getP(); if (!op.getGenerator().getType().isa()) return rewriter.notifyMatchFailure( - op, "The generator has to ben None because only global default " + op, "The generator has to be None because only global default " "generator is supported"); Value output; if (failed( @@ -3543,6 +4138,80 @@ class DecomposeAtenBernoulliTensorOp } // namespace namespace { +// Decompose exponential() to do inverse transform sampling. +// - https://en.wikipedia.org/wiki/Inverse_transform_sampling +// With the exponential distribution, F(x) = 1 - exp(-lambda * x). Thus, +// exponential() = - ln(1 - uniform(0, 1)) / lambda. +class DecomposeAtenExponentialOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenExponentialOp op, + PatternRewriter &rewriter) const override { + if (!op.getGenerator().getType().isa()) + return rewriter.notifyMatchFailure( + op, "The generator has to be None because only global default " + "generator is supported"); + + Location loc = op.getLoc(); + Type resultType = op.getType(); + + // Create a uniform random op with low and high set to 0.0 and 1.0, + // respectively. + Value none = rewriter.create(loc); + Value zero = + rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + Value one = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + Value emptyTensor = rewriter.create( + loc, resultType, op.getSelf(), zero, /*dtype=*/none, /*layout=*/none, + /*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none); + Value x = rewriter.create(loc, resultType, emptyTensor, + /*from=*/zero, /*to=*/one, + /*generator=*/none); + + Value negX = rewriter.create(loc, resultType, x); + Value oneMinusX = + rewriter.create(loc, resultType, negX, one, + /*alpha=*/one); + Value lnOneMinusX = rewriter.create(loc, resultType, oneMinusX); + Value negLambda = rewriter.create(loc, op.getLambd()); + rewriter.replaceOpWithNewOp(op, resultType, lnOneMinusX, + negLambda); + return success(); + } +}; + +// aten.normal_functional(mean, sigma) = randn() * sigma + mean. +class DecomposeAtenNormalFunctionalOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNormalFunctionalOp op, + PatternRewriter &rewriter) const override { + if (!op.getGenerator().getType().isa()) + return rewriter.notifyMatchFailure( + op, "The generator has to be None because only global default " + "generator is supported"); + + Location loc = op.getLoc(); + Type resultType = op.getType(); + Value std = op.getStd(); + Value mean = op.getMean(); + + Value none = rewriter.create(loc); + Value one = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + Value randN = rewriter.create( + loc, resultType, op.getSelf(), /*dtype=*/none, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none); + Value stdRandN = + rewriter.create(loc, resultType, randN, std); + rewriter.replaceOpWithNewOp(op, resultType, stdRandN, mean, + /*alpha=*/one); + return success(); + } +}; + template class DecomposeAtenAddCLikeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -3591,6 +4260,143 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenInstanceNormOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenInstanceNormOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + auto context = op.getContext(); + + auto inputTy = op.getInput().getType().cast(); + int64_t inputRank = inputTy.getSizes().size(); + SmallVector reducedShape(inputTy.getSizes()); + SmallVector reduceDimInts; + SmallVector reduceDimVals; + for (int i = 2; i < inputRank; ++i) { + reducedShape[i] = 1; + reduceDimVals.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(i))); + } + + Type dtype = inputTy.getOptionalDtype(); + Type reducedTy = ValueTensorType::get(op.getContext(), + llvm::ArrayRef(reducedShape), dtype); + + auto sizeListType = ListType::get(IntType::get(context)); + Value reduceDimList = + rewriter.create(loc, sizeListType, reduceDimVals); + Value cstTrue = rewriter.create(loc, true); + Value none = rewriter.create(loc); + + Value one = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + + // mean(x) + Value inputMean = rewriter.create( + loc, reducedTy, op.getInput(), reduceDimList, cstTrue, none); + + // x - mean(x) + Value inputMeanExpanded = + rewriter.create(loc, inputTy, inputMean, op.getInput()); + Value inputSubMean = rewriter.create( + loc, inputTy, op.getInput(), inputMeanExpanded, one); + // (x - mean(x))^2 + Value inputSubMeanSquare = rewriter.create( + loc, inputTy, inputSubMean, inputSubMean); + + Value variancesum = rewriter.create( + loc, reducedTy, inputSubMeanSquare, reduceDimList, cstTrue, + /*dtype=*/none); + + int64_t elemCount = 1; + for (int i = 2; i < inputRank; ++i) + elemCount *= inputTy.getSizes()[i]; + + Value hw = rewriter.create( + loc, rewriter.getI64IntegerAttr(elemCount)); + Value inputVar = + rewriter.create(loc, reducedTy, variancesum, hw); + + // rsqrt(var(x) + eps) + Value inputVarPlusEps = rewriter.create( + loc, reducedTy, inputVar, op.getEps(), one); + Value inputRsqrtVar = + rewriter.create(loc, reducedTy, inputVarPlusEps); + + // (x - mean(x)) * rsqrt(var(x) + eps) + Value inputRsqrtVarExpanded = rewriter.create( + loc, inputTy, inputRsqrtVar, op.getInput()); + Value inputNormalized = rewriter.create( + loc, inputTy, inputSubMean, inputRsqrtVarExpanded); + Value out = rewriter.create( + loc, op.getResult().getType(), inputNormalized); + + Value weight = op.getWeight(); + auto weightTy = weight.getType().cast(); + dtype = weightTy.getOptionalDtype(); + + SmallVector weightShape(weightTy.getSizes()); + SmallVector newWeightShape; + newWeightShape.push_back(1); + newWeightShape.append(weightShape); + + Value zero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Type newWeightTy = ValueTensorType::get( + op.getContext(), llvm::ArrayRef(newWeightShape), dtype); + weight = rewriter.create(loc, newWeightTy, weight, zero); + + while (static_cast(newWeightShape.size()) < inputRank) { + Value i = rewriter.create( + loc, rewriter.getI64IntegerAttr(newWeightShape.size())); + newWeightShape.push_back(1); + newWeightTy = ValueTensorType::get(op.getContext(), + llvm::ArrayRef(newWeightShape), dtype); + weight = rewriter.create(loc, newWeightTy, weight, i); + } + + Value weightExpanded = + rewriter.create(loc, inputTy, weight, op.getInput()); + + Value bias = op.getBias(); + auto biasTy = bias.getType().cast(); + dtype = biasTy.getOptionalDtype(); + + SmallVector biasShape(biasTy.getSizes()); + SmallVector newBiasShape; + newBiasShape.push_back(1); + newBiasShape.append(biasShape); + + Type newBiasTy = ValueTensorType::get(op.getContext(), + llvm::ArrayRef(newBiasShape), dtype); + bias = rewriter.create(loc, newBiasTy, bias, zero); + + while (static_cast(newBiasShape.size()) < inputRank) { + Value i = rewriter.create( + loc, rewriter.getI64IntegerAttr(newBiasShape.size())); + newBiasShape.push_back(1); + newBiasTy = ValueTensorType::get(op.getContext(), + llvm::ArrayRef(newBiasShape), dtype); + bias = rewriter.create(loc, newBiasTy, bias, i); + } + + Value biasExpanded = + rewriter.create(loc, inputTy, bias, op.getInput()); + + out = rewriter.create(loc, out.getType(), out, + weightExpanded); + out = rewriter.create(loc, out.getType(), out, + biasExpanded, one); + + rewriter.replaceOp(op, out); + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenNativeLayerNormOp : public OpRewritePattern { @@ -3753,6 +4559,165 @@ class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenGroupNormOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenGroupNormOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *context = op.getContext(); + + Value input = op.getInput(); + Value weight = op.getWeight(); + Value bias = op.getBias(); + Value numGroups = op.getNumGroups(); + Value eps = op.getEps(); + + Value cstZero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value cstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + auto baseType = ValueTensorType::getWithLeastStaticInformation(context); + + Value N = rewriter.create(loc, input, cstZero); + Value C = rewriter.create(loc, input, cstOne); + Value numElements = rewriter.create(loc, input); + Value numElementsDivN = + rewriter.create(loc, numElements, N); + Value HxW = rewriter.create(loc, numElementsDivN, C); + + AtenNativeGroupNormOp newOp = rewriter.create( + loc, ArrayRef{op.getResult().getType(), baseType, baseType}, + input, weight, bias, N, C, HxW, numGroups, eps); + + rewriter.replaceOp(op, newOp.getResult0()); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenNativeGroupNormOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNativeGroupNormOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *context = op.getContext(); + + Value input = op.getInput(); + Value weight = op.getWeight(); + Value bias = op.getBias(); + Value numGroups = op.getGroup(); + Value eps = op.getEps(); + + // Check the rank of the input/outputs tensor. + auto inputType = input.getType().cast(); + auto outputType = op.getResult0().getType().cast(); + auto meanType = op.getResult1().getType().cast(); + auto rsqrtVarType = op.getResult2().getType().cast(); + if (!inputType.hasSizes() || !outputType.hasSizes() || + !meanType.hasSizes() || !rsqrtVarType.hasSizes()) { + return rewriter.notifyMatchFailure( + op, "input/outputs tensor should have known sizes."); + } + + Value none = rewriter.create(loc); + Value cstZero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value cstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value cstNegtiveOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); + Value cstTrue = rewriter.create(loc, true); + Value cstFalse = rewriter.create(loc, false); + auto baseType = ValueTensorType::getWithLeastStaticInformation(context); + + // GroupNorm requires the channel dimension (C) to be exactly divisible by + // the number of groups. + Value channel = rewriter.create(loc, input, cstOne); + Value remainder = + rewriter.create(loc, channel, numGroups); + Value eqOrNot = rewriter.create(loc, remainder, cstZero); + rewriter.create( + loc, eqOrNot, + rewriter.getStringAttr("the number of channels must be divisible by " + "the number of groups")); + + // Reshape the input tensor to (N, numGroups, -1) to apply normalization. + SmallVector newShape; + newShape.push_back(rewriter.create(loc, input, cstZero)); + newShape.push_back(numGroups); + newShape.push_back(cstNegtiveOne); + Value reshapedInput = rewriter.create( + loc, baseType, input, + rewriter.create( + loc, Torch::ListType::get(IntType::get(context)), newShape)); + + // Now we proceed with the normalization steps across the 'groupSize' + // Compute the mean and variance for each group + Value dimList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), + ArrayRef{cstNegtiveOne}); + auto mean = rewriter.create( + loc, baseType, reshapedInput, /*dims=*/dimList, /*keepdim=*/cstTrue, + /*dtype=*/none); + auto var = rewriter.create( + loc, baseType, reshapedInput, /*dims=*/dimList, /*unbiased=*/cstFalse, + /*keepdim=*/cstTrue); + + // Compute the normalized output: (input - mean) * rsqrt(var + eps) + auto varPlusEps = rewriter.create(loc, baseType, var, eps, + /*alpha=*/cstOne); + auto invStd = rewriter.create(loc, baseType, varPlusEps); + auto inputSubMean = rewriter.create( + loc, baseType, reshapedInput, mean, /*alpha=*/cstOne); + auto normalizedOutput = + rewriter.create(loc, baseType, inputSubMean, invStd); + + // Reshape normalized output back to the original input shape + auto inputShape = rewriter.create( + loc, Torch::ListType::get(IntType::get(context)), input); + auto reshapedOutput = rewriter.create( + loc, inputType, normalizedOutput, /*shape=*/inputShape); + + // Apply weight and bias if they are not None + // Reshape weight and bias to C,1,1,... + SmallVector viewShape = {channel}; + for (unsigned i = 2; i < inputType.getSizes().size(); i++) { + viewShape.push_back(cstOne); + } + Value viewShapeSizeList = rewriter.create( + loc, ListType::get(IntType::get(context)), viewShape); + + Value groupNormOutput = reshapedOutput; + if (!weight.getType().isa()) { + auto weightReshaped = rewriter.create( + loc, baseType, weight, /*shape=*/viewShapeSizeList); + groupNormOutput = rewriter.create( + loc, inputType, groupNormOutput, weightReshaped); + } + if (!bias.getType().isa()) { + auto biasReshaped = rewriter.create( + loc, baseType, bias, /*shape=*/viewShapeSizeList); + groupNormOutput = rewriter.create( + loc, inputType, groupNormOutput, biasReshaped, + /*alpha=*/cstOne); + } + + Value squeezedMean = + rewriter.create(loc, meanType, mean, cstNegtiveOne); + Value squeezedRsqrtVar = rewriter.create( + loc, rsqrtVarType, invStd, cstNegtiveOne); + + rewriter.replaceOp( + op, ArrayRef{groupNormOutput, squeezedMean, squeezedRsqrtVar}); + + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenNativeBatchNormOp : public OpRewritePattern { @@ -4984,8 +5949,8 @@ class DecomposeAten_EmbeddingBagOp auto resultType2 = op->getResult(2).getType(); auto resultType3 = op->getResult(3).getType(); - mlir::TypeRange returnTypes{resultType0, resultType1, resultType2, - resultType3}; + llvm::SmallVector returnTypes{resultType0, resultType1, resultType2, + resultType3}; rewriter.replaceOpWithNewOp( op, returnTypes, weight, indices, offsets, scaleGradByFreq, mode, @@ -5383,6 +6348,78 @@ class DecomposeAtenRandOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenLinspaceOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenLinspaceOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *context = getContext(); + + auto baseType = ValueTensorType::getWithLeastStaticInformation(context); + Value none = rewriter.create(loc); + Value falseVal = rewriter.create(loc, false); + Value zero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + + Value addStart; + int64_t steps; + if (matchPattern(op.getSteps(), m_TorchConstantInt(&steps)) && steps == 1) { + // specically handle steps == 1 + Value arange = rewriter.create( + loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(), + op.getDevice(), op.getPinMemory()); + addStart = rewriter.create(loc, baseType, arange, + op.getStart(), one); + } else { + // handle steps != 1 or dynamic steps + Value neOrNot = rewriter.create(loc, op.getSteps(), one); + rewriter.create( + loc, neOrNot, + rewriter.getStringAttr("linspace's dynamic steps must not be 1")); + // create arange: [0, ..., steps - 1] + Value arange = rewriter.create( + loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(), + op.getDevice(), op.getPinMemory()); + // calculate (end - start) / (steps - 1) + Value sub; + if (op.getEnd().getType().isa() || + op.getStart().getType().isa()) { + sub = rewriter.create(loc, Torch::FloatType::get(context), + op.getEnd(), op.getStart()); + } else { + sub = rewriter.create(loc, op.getEnd(), op.getStart()); + } + Value div = rewriter.create( + loc, sub, rewriter.create(loc, op.getSteps(), one)); + // calculate [0, ..., steps - 1] * ((end - start) / (steps - 1)) + start + Value mulScalar = + rewriter.create(loc, baseType, arange, div); + addStart = rewriter.create(loc, baseType, mulScalar, + op.getStart(), one); + } + // to dtype + Value result; + if (!op.getDtype().getType().isa()) { + result = rewriter.create( + loc, op.getType(), addStart, op.getDtype(), /*non_blocking=*/falseVal, + /*copy=*/falseVal, /*memory_format=*/none); + } else { + Value f32Type = rewriter.create( + loc, (int)torch_upstream::ScalarType::Float); + result = rewriter.create( + loc, op.getType(), addStart, f32Type, /*non_blocking=*/falseVal, + /*copy=*/falseVal, /*memory_format=*/none); + } + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenVarMeanOp : public OpRewritePattern { public: @@ -6260,6 +7297,37 @@ class DecomposeAtenReshapeAsOp : public OpRewritePattern { }; } // namespace +namespace { +// Decompose AtenLinalgNormOp to AtenLinalgVectorNormOp only +class DecomposeAtenLinalgNormOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenLinalgNormOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + SmallVector dimList; + if (!getListConstructElements(op.getDim(), dimList)) { + return rewriter.notifyMatchFailure( + op, "dim should comes from a PrimListConstructOp"); + } + if (dimList.size() != 1) { + return rewriter.notifyMatchFailure( + op, "Unimplemented: only dim size of 1 is supported"); + } + + // default ord value is 2 for vector_norm + auto ord = op.getOrd(); + if (ord.getType().isa()) { + ord = rewriter.create(loc, rewriter.getI64IntegerAttr(2)); + } + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), ord, op.getDim(), op.getKeepdim(), + op.getDtype()); + return success(); + } +}; +} // namespace + namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { @@ -6314,6 +7382,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -6323,9 +7392,11 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal( @@ -6334,8 +7405,11 @@ class DecomposeComplexOpsPass DecomposeAtenAddCLikeOp>(patterns); addPatternIfTargetOpIsIllegal< DecomposeAtenAddCLikeOp>(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal< DecomposeAten_ConvolutionLikeOp>(patterns); @@ -6343,12 +7417,14 @@ class DecomposeComplexOpsPass DecomposeAten_ConvolutionLikeOp>( patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal>(patterns); - addPatternIfTargetOpIsIllegal>(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenArgMinMaxOp>(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenArgMinMaxOp>(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -6361,15 +7437,19 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal< DecomposeAtenBernoulliLikeOp>(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -6413,6 +7493,8 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -6430,13 +7512,17 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -6460,6 +7546,12 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + // More specific conv ops + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); GreedyRewriteConfig config; config.useTopDownTraversal = true; diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp new file mode 100644 index 000000000000..6bc8a8ba084a --- /dev/null +++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp @@ -0,0 +1,241 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" + +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +namespace { + +template +class QuantizeOperands : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SrcOp op, + PatternRewriter &rewriter) const override { + llvm::SmallVector operands(op->getOperands()); + + bool dequanted = false; + auto f = [&dequanted](Value operand) { + if (auto dequant = operand.getDefiningOp()) { + operand = dequant.getOperand(); + dequanted = true; + } + if (auto dequant = operand.getDefiningOp()) { + operand = dequant.getOperand(); + dequanted = true; + } + return operand; + }; + + operands[0] = f(operands[0]); + operands[1] = f(operands[1]); + + if (!dequanted) { + return rewriter.notifyMatchFailure(op, "no dequantizations found"); + } + + rewriter.replaceOpWithNewOp(op, op.getType(), operands); + return success(); + } +}; + +template class QuantizeBias : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SrcOp op, + PatternRewriter &rewriter) const override { + llvm::SmallVector operands(op->getOperands()); + if (operands.size() < 3) + return failure(); + + Value lhsScale; + if (auto qLhs = + operands[0].getDefiningOp()) + lhsScale = qLhs.getScale(); + + Value rhsScale; + if (auto qRhs = + operands[1].getDefiningOp()) + rhsScale = qRhs.getScale(); + + if (!rhsScale || !lhsScale) + return failure(); + + auto resultTy = cast(op.getType()); + if (!isa(resultTy.getDtype())) + return failure(); + + Value bias = operands[2]; + auto biasTy = bias.getType().dyn_cast(); + + if (biasTy) { + auto biasETy = biasTy.getOptionalDtype(); + if (!biasETy || !isa(biasETy)) + return failure(); + } + + Value biasScale = rewriter.create( + op.getLoc(), lhsScale.getType(), lhsScale, rhsScale); + + Value zero = rewriter.create( + op.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + + auto qi32Ty = rewriter.getType(); + + if (biasTy) { + auto newBiasTy = + rewriter.getType(biasTy.getOptionalSizes(), qi32Ty); + Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), qi32Ty); + bias = rewriter.create( + op.getLoc(), newBiasTy, bias, biasScale, zero, dtype); + bias = rewriter.create( + op.getLoc(), + rewriter.getType( + biasTy.getOptionalSizes(), + rewriter.getIntegerType(32, IntegerType::Signed)), + bias); + operands[2] = bias; + } + + auto convTy = rewriter.getType( + resultTy.getOptionalSizes(), + rewriter.getIntegerType(32, IntegerType::Signed)); + auto conv = rewriter.create(op.getLoc(), convTy, operands); + + auto convQTy = + rewriter.getType(resultTy.getOptionalSizes(), qi32Ty); + auto makeOut = rewriter.create( + op.getLoc(), convQTy, conv, biasScale, zero); + rewriter.replaceOpWithNewOp(op, op.getType(), + makeOut); + + return success(); + } +}; + +template +class QuantizeAccumulator : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SrcOp op, + PatternRewriter &rewriter) const override { + auto lhs = op.getOperand(0); + auto rhs = op.getOperand(1); + + auto resultTy = dyn_cast_or_null(op.getType()); + if (!resultTy || !resultTy.hasDtype()) + return failure(); + + Type resultETy = resultTy.getDtype(); + if (!resultETy.isa()) + return failure(); + + Value lhsScale; + if (auto defining = + lhs.template getDefiningOp()) { + lhsScale = defining.getScale(); + } + + Value rhsScale; + if (auto defining = + rhs.template getDefiningOp()) { + rhsScale = defining.getScale(); + } + + if (!lhsScale || !rhsScale) + return failure(); + + // Quantize the bias input to the expected result: + Value zero = rewriter.create( + op.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + + auto qi32Ty = rewriter.getType(); + Value biasScale = rewriter.create( + op.getLoc(), lhsScale.getType(), lhsScale, rhsScale); + + // Update the quantied type: + llvm::SmallVector operands(op.getOperands()); + + auto newResultTy = + rewriter.getType(resultTy.getOptionalSizes(), qi32Ty); + auto conv = rewriter.create(op.getLoc(), newResultTy, operands); + + // Attach the quantize information to the resulting qint32: + auto intReprTy = rewriter.getType( + resultTy.getOptionalSizes(), + rewriter.getIntegerType(32, IntegerType::Signed)); + auto intRepr = rewriter.create(op.getLoc(), intReprTy, conv); + + auto quantTy = + rewriter.getType(resultTy.getOptionalSizes(), qi32Ty); + auto quant = rewriter.create( + op.getLoc(), quantTy, intRepr, biasScale, zero); + auto dequant = + rewriter.create(op.getLoc(), resultTy, quant); + rewriter.replaceOp(op, dequant); + + return success(); + } +}; + +template class RemoveUnused : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SrcOp op, + PatternRewriter &rewriter) const override { + auto result = op.getResult(); + if (result.use_empty()) { + op.erase(); + return success(); + } + return failure(); + } +}; + +class FuseQuantizedOpsPass : public FuseQuantizedOpsBase { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns + .insert, + RemoveUnused, + RemoveUnused, + QuantizeOperands, QuantizeOperands, + QuantizeAccumulator, QuantizeBias>( + context); + + GreedyRewriteConfig config; + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +mlir::torch::Torch::createFuseQuantizedOpsPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp index da8be9b17e0b..239960629797 100644 --- a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp @@ -170,8 +170,8 @@ class ObjectGraphInfo { auto attr = std::get<1>(t); nameStack.push_back(attr.getName().str()); if (attr.getType().isa()) { - if (failed( - recursivelyTraverse(slot.getValue().getDefiningOp()))) + if (failed(recursivelyTraverse( + slot.getValue().getDefiningOp()))) return failure(); } else if (usedSlots.find(slot) != usedSlots.end()) { // Only create the GlobalSlotOp if the slot is used at all. @@ -190,8 +190,8 @@ class ObjectGraphInfo { } for (auto method : classType.getOps()) { nameStack.push_back(method.getName().str()); - funcLinkageInfo[{nnModule, - symbolTable.lookup(method.getFunction())}] = + funcLinkageInfo[{ + nnModule, symbolTable.lookup(method.getFunction())}] = LinkageInfo{llvm::join(nameStack, "."), method.getIsPrivate()}; nameStack.pop_back(); } @@ -501,21 +501,24 @@ static LogicalResult rewriteMonomorphizedFuncClone( SmallVector toErase; auto handlePrimSetAttr = [&](PrimSetAttrOp op) { - auto instance = mapping.lookup(op.getReceiver()).getDefiningOp(); + auto instance = + mapping.lookup(op.getReceiver()).getDefiningOp(); SlotOp affectedSlot; for (auto slot : instance.getOps()) { if (slot.getName() == op.getName()) affectedSlot = slot; } OpBuilder(op).create( - op.getLoc(), objectGraphInfo.getGlobalSlotFor(affectedSlot).getSymName(), + op.getLoc(), + objectGraphInfo.getGlobalSlotFor(affectedSlot).getSymName(), op.getValue()); toErase.push_back(op); return WalkResult::advance(); }; auto handlePrimGetAttr = [&](PrimGetAttrOp op) { if (!op.getType().isa()) { - auto instance = mapping.lookup(op.getReceiver()).getDefiningOp(); + auto instance = + mapping.lookup(op.getReceiver()).getDefiningOp(); SlotOp affectedSlot; for (auto slot : instance.getOps()) { if (slot.getName() == op.getName()) diff --git a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp index 76b57fe8c9a3..1e8c90deac4e 100644 --- a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp +++ b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp @@ -95,7 +95,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) { class InlineGlobalSlotsAnalysisState : public AnalysisState { public: InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) { - setSafe(); + (void)setSafe(); } void print(raw_ostream &os) const override { @@ -163,7 +163,8 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) { } if (auto globalSlotSet = dyn_cast(op)) { auto *state = getOrCreate( - getProgramPoint(globalSlotSet.getSlotAttr())); + getProgramPoint( + globalSlotSet.getSlotAttr())); propagateIfChanged(state, state->setSafe(false)); } // Save the InitializeGlobalSlotsOp for later referencee @@ -211,8 +212,8 @@ LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) { auto it = llvm::find(initializeGlobalSlotsOp.getSlotSymNames(), static_cast(flatSymbolRefPoint->getValue())); - Value value = initializeGlobalSlotsOp->getOperand( - std::distance(initializeGlobalSlotsOp.getSlotSymNames().begin(), it)); + Value value = initializeGlobalSlotsOp->getOperand(std::distance( + initializeGlobalSlotsOp.getSlotSymNames().begin(), it)); auto *flatSymbolRefState = getOrCreateFor(value, flatSymbolRefPoint); @@ -331,7 +332,8 @@ class InlineGlobalSlotsPass DenseSet safeToInline; for (int i = 0, e = initialize->getNumOperands(); i != e; i++) { - auto slotSymName = initialize.getSlotSymNames()[i].cast(); + auto slotSymName = + initialize.getSlotSymNames()[i].cast(); Value operand = initialize.getOperand(i); auto symbolRefPoint = solver.getProgramPoint( initialize.getSlotSymNames()[i].cast()); @@ -405,7 +407,8 @@ class InlineGlobalSlotsPass SmallVector newSlotSymNames; SmallVector newInitialValues; for (int i = 0, e = initialize.getNumOperands(); i != e; i++) { - auto slotSymName = initialize.getSlotSymNames()[i].cast(); + auto slotSymName = + initialize.getSlotSymNames()[i].cast(); if (!safeToInline.count(slotSymName)) { newSlotSymNames.push_back(slotSymName); newInitialValues.push_back(initialize.getOperand(i)); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index b0cd84ff6bfd..d7aa37ca7baf 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -202,15 +202,16 @@ static bool satisfiesBackendContract(ModuleOp module, // Check for unimplemented operators first to give more direct diagnostics. walkResult0 = module.walk([&](Torch::OperatorOp op) { if (llvm::all_of(op.getResults(), [&op](auto res) { - return succeeded( - checkType(op.getOperation(), res.getType(), /*actuallyEmitDiagnostics=*/false)); + return succeeded(checkType(op.getOperation(), res.getType(), + /*actuallyEmitDiagnostics=*/false)); })) { return WalkResult::advance(); } if (actuallyEmitDiagnostics) { - op->emitError("unsupported by backend contract: Unimplemented operator '" - + op.getName() + "'"); + op->emitError( + "unsupported by backend contract: Unimplemented operator '" + + op.getName() + "'"); } return WalkResult::interrupt(); }); @@ -309,20 +310,22 @@ class LowerToBackendContractPass << " iterations of the simplification pipeline\n"; }); } + private: llvm::StringSet<> backendLegalOpsSet; }; class VerifyBackendContractNoDecompositionsPass - : public VerifyBackendContractNoDecompositionsBase { + : public VerifyBackendContractNoDecompositionsBase< + VerifyBackendContractNoDecompositionsPass> { public: VerifyBackendContractNoDecompositionsPass() = default; void runOnOperation() override { MLIRContext *context = &getContext(); ConversionTarget target = - getBackendContractTarget(context, /*decompose*/false, - /*backendLegalOpsSet*/{}); + getBackendContractTarget(context, /*decompose*/ false, + /*backendLegalOpsSet*/ {}); if (!satisfiesBackendContract(getOperation(), target, /*actuallyEmitDiagnostics=*/true)) { @@ -376,7 +379,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -386,12 +388,14 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -405,15 +409,22 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, }); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -425,16 +436,21 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -475,6 +491,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -487,6 +504,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -504,6 +522,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); for (auto &opName : backendLegalOpsSet) { target.addLegalOp( OperationName(kTorchOpPrefix + opName.first().str(), context)); diff --git a/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp new file mode 100644 index 000000000000..147f16c08eb3 --- /dev/null +++ b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp @@ -0,0 +1,109 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" + +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +namespace { + +Type getQuantizedType(MLIRContext *context, Type t) { + if (t.isSignlessInteger(8)) + return Torch::QUInt8Type::get(context); + if (t.isInteger(8) || t.isSignedInteger(8)) + return Torch::QInt8Type::get(context); + if (t.isInteger(32)) + return Torch::QInt32Type::get(context); + return {}; +} + +class MatchQuantizeOperator : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OperatorOp op, + PatternRewriter &rewriter) const override { + if (op.getName() == "torch.quantized_decomposed.quantize_per_tensor") { + auto resultTy = cast(op.getType(0)); + auto qeTy = getQuantizedType(rewriter.getContext(), resultTy.getDtype()); + if (!qeTy) + qeTy = resultTy.getDtype(); + + auto qTy = + rewriter.getType(resultTy.getOptionalSizes(), qeTy); + Value quant = rewriter.create( + op.getLoc(), qTy, + /*self=*/op.getOperand(0), /*scale=*/op.getOperand(1), + /*zero_point=*/op.getOperand(2), /*dtype=*/op.getOperand(5)); + + if (qTy != resultTy) { + quant = rewriter.create(op.getLoc(), resultTy, quant); + } + + rewriter.replaceOpWithNewOp( + op, resultTy, quant, op.getOperand(3), op.getOperand(4)); + return success(); + } + + if (op.getName() == "torch.quantized_decomposed.dequantize_per_tensor") { + auto clamp = rewriter.create( + op.getLoc(), op.getOperand(0).getType(), op.getOperand(0), + op.getOperand(3), op.getOperand(4)); + + auto clampTy = clamp.getType().cast(); + if (!clampTy.hasDtype()) + return rewriter.notifyMatchFailure(op, + "dequantization has unknown dtype"); + + Type dtype = clampTy.getDtype(); + Type qetype = getQuantizedType(op.getContext(), dtype); + if (!qetype) + return rewriter.notifyMatchFailure(op, + "dequantization has unknown qtype"); + + Type qTy = Torch::ValueTensorType::get( + op.getContext(), clampTy.getOptionalSizes(), qetype); + auto quant = rewriter.create( + op.getLoc(), qTy, clamp, op.getOperand(1), op.getOperand(2)); + rewriter.replaceOpWithNewOp( + op, op.getResultTypes(), quant); + return success(); + } + + return failure(); + } +}; + +class MatchQuantizedCustomOpsPass + : public MatchQuantizedCustomOpsBase { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.insert(context); + + GreedyRewriteConfig config; + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config))) + return signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> +mlir::torch::Torch::createMatchQuantizedCustomOpsPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp index cd76275a745d..7db6bc6776b3 100644 --- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp +++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp @@ -175,7 +175,7 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock // Replace return type of view-like ops with value-semantics type variant. for (Operation *viewLikeOp : ops.viewLikeOps) { - rewriter.updateRootInPlace(viewLikeOp, [&] { + rewriter.modifyOpInPlace(viewLikeOp, [&] { Value result = viewLikeOp->getResult(0); auto resultType = result.getType().dyn_cast(); if (resultType) @@ -337,7 +337,7 @@ class RewriteViewLikeSubgraph // correctly copy them back to their mlir::func::ReturnOp's expected types. DenseMap originalTypes; for (Operation *op : viewLikeOps) { - rewriter.updateRootInPlace(op, [&]() { + rewriter.modifyOpInPlace(op, [&]() { if (auto nonValueTensorType = op->getResult(0).getType().dyn_cast()) { originalTypes[op->getResult(0)] = nonValueTensorType; diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index a4b02cf9e17f..f8161de1fa0b 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -9,10 +9,10 @@ #include "PassDetail.h" +#include "ReifyAbstractInterpCalculationsUtils.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" -#include "ReifyAbstractInterpCalculationsUtils.h" #include "llvm/ADT/StringExtras.h" using namespace mlir; @@ -72,8 +72,8 @@ namespace { // immutable tensors. class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { public: - ConvertHasValueSemanticsOpsToValueTensors(MLIRContext *context, - const std::optional& extraLibrary) + ConvertHasValueSemanticsOpsToValueTensors( + MLIRContext *context, const std::optional &extraLibrary) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) { this->extraLibrary = extraLibrary; } @@ -87,7 +87,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { return rewriter.notifyMatchFailure(op, "does not have value semantics"); } - rewriter.startRootUpdate(op); + rewriter.startOpModification(op); // Convert all operands. SmallVector newOperands; for (OpOperand &opOperand : op->getOpOperands()) { @@ -105,7 +105,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { auto listConstruct = opOperand.get().getDefiningOp(); if (!listConstruct) { - rewriter.cancelRootUpdate(op); + rewriter.cancelOpModification(op); return rewriter.notifyMatchFailure( op, "unimplemented: list of non vtensor type not constructed " "from list construct"); @@ -120,7 +120,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { if (!llvm::all_of(listConstruct.getElements(), [](Value val) { return val.getType().isa(); })) { - rewriter.cancelRootUpdate(op); + rewriter.cancelOpModification(op); return rewriter.notifyMatchFailure( op, "unimplemented: list containing optional type is not " "handled."); @@ -138,7 +138,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { Type newListType = getContainerOrTensorTypeWithValueSemantics(listType); if (!newListType) { - rewriter.cancelRootUpdate(op); + rewriter.cancelOpModification(op); return rewriter.notifyMatchFailure( op, "Unable to convert list type to value semantics."); } @@ -154,7 +154,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { // from the non value tensor of the original optional value. auto derefine = opOperand.get().getDefiningOp(); if (!derefine) { - rewriter.cancelRootUpdate(op); + rewriter.cancelOpModification(op); return rewriter.notifyMatchFailure( op, "unimplemented: optional of non vtensor type not from " "derefine"); @@ -180,14 +180,87 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { rewriter.create(op->getLoc(), result); result.replaceAllUsesExcept(nonValueTensor, nonValueTensor); } - rewriter.finalizeRootUpdate(op); + rewriter.finalizeOpModification(op); return success(); } + private: std::optional extraLibrary; }; } // namespace +namespace { + +class TorchMatchSpecializedBackendOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + using HandlerFn = LogicalResult (*)(OperatorOp op, + ConversionPatternRewriter &rewriter); + + LogicalResult + matchAndRewrite(Torch::OperatorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (namedHandlers.contains(op.getNameAttr())) { + return namedHandlers.lookup(op.getNameAttr()).front()(op, rewriter); + } + + return failure(); + } + + static void + populateSpecializedConversions(TorchMatchSpecializedBackendOp &matcher); + + static std::unique_ptr + getPopulatedMatcher(MLIRContext *context) { + auto matcher = std::make_unique(context); + populateSpecializedConversions(*matcher); + return matcher; + }; + + void populate(StringRef name, HandlerFn fn) { + namedHandlers[StringAttr::get(getContext(), name)].push_back(fn); + } + + void populateLegalizedNames(llvm::DenseSet &set) { + for (auto handle : namedHandlers) { + set.insert(handle.first); + } + } + +private: + DenseMap> namedHandlers; +}; + +void TorchMatchSpecializedBackendOp::populateSpecializedConversions( + TorchMatchSpecializedBackendOp &matcher) { + matcher.populate( + "torch.aten._scaled_dot_product_flash_attention_for_cpu", + [](Torch::OperatorOp op, + ConversionPatternRewriter &rewriter) -> LogicalResult { + auto uses = op.getResult(1).getUses(); + if (uses.end() == uses.begin()) { + auto oldOperands = op->getOperands(); + llvm::SmallVector newOperands{ + oldOperands[0], oldOperands[1], oldOperands[2], oldOperands[5], + oldOperands[3], oldOperands[4], oldOperands[6]}; + + auto newOp = rewriter.create( + op.getLoc(), op->getResultTypes()[0], newOperands, + op->getAttrs()); + rewriter.replaceAllUsesWith(op.getResult(0), newOp.getResult()); + rewriter.eraseOp(op); + return success(); + } + return failure(); + }); +} + +bool isSpecializedOperation(Torch::OperatorOp op) { return true; } +} // namespace + // Reduce Ops without value semantics but the corresponding without trailing // underscore variant doesn't exist. namespace { @@ -274,7 +347,7 @@ class ReduceTrailingUnderscoreInplaceVariant : public RewritePattern { SmallVector fragments; llvm::SplitString(op->getName().getStringRef(), fragments, "."); - assert(fragments.size() >= 3 && fragments[2].endswith("_") && + assert(fragments.size() >= 3 && fragments[2].ends_with("_") && "IsTrailingUnderscoreInplaceVariant incorrectly applied"); fragments[2] = fragments[2].drop_back(); std::string noUnderscoreName = llvm::join(fragments, "."); @@ -290,9 +363,9 @@ class ReduceTrailingUnderscoreInplaceVariant : public RewritePattern { Operation *newOp = rewriter.create(state); // Note: need to convert result to first input's dtype because mix precision // compute would result in different behaviors. - // For example: - // a = torch.randn(3, 3).half() # float16 - // b = torch.randn(3, 3) # float32 + // For example: + // a = torch.randn(3, 3).half() # float16 + // b = torch.randn(3, 3) # float32 // a += b # i.e. torch.ops.aten.add_(a, b), result is float16 // c = a + b # i.e. torch.ops.aten.add(a, b), result is float32 Value none = rewriter.create(op->getLoc()); @@ -300,7 +373,8 @@ class ReduceTrailingUnderscoreInplaceVariant : public RewritePattern { auto aDtype = rewriter.create(op->getLoc(), op->getOperand(0)); auto toDtype = rewriter.create( op->getLoc(), newOp->getResult(0).getType(), newOp->getResult(0), - aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); + aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); auto tensor = rewriter.create(op->getLoc(), toDtype); createOverwriteTensorContents(rewriter, op->getLoc(), tensor, op->getOperand(0)); @@ -351,12 +425,24 @@ struct ReduceOpVariantsPass patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp); patterns.add(context); + // Create specialized matcher: + auto specialized = + TorchMatchSpecializedBackendOp::getPopulatedMatcher(context); + DenseSet specializedNames; + specialized->populateLegalizedNames(specializedNames); + patterns.insert(std::move(specialized)); + ConversionTarget target(*context); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable]( - Operation *op) { + target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable, + &specializedNames](Operation *op) { + if (isa(op)) { + if (specializedNames.contains(cast(op).getNameAttr())) { + return false; + } + } if (op->hasTrait() || (isa(op) && operatorOpHasValueSemantics(cast(op), @@ -375,6 +461,9 @@ struct ReduceOpVariantsPass if (op->hasTrait()) { return false; } + + if (isa(op) && isSpecializedOperation(cast(op))) + return false; return true; }); diff --git a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp index 290beb1da7c9..a34e0208c9d9 100644 --- a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp @@ -78,7 +78,7 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable( // mechanically consistent with existing torch conventions of in-place vs. // out-of-place (value-semantic) variants), remove the prefix when // looking them up in the library. - if (name.startswith("valsem.")) + if (name.starts_with("valsem.")) name = name.drop_front(strlen("valsem.")); if (isa(op)) name = cast(op)->getAttr("name").cast().getValue(); @@ -158,9 +158,11 @@ void Torch::importLibraryFunctions(ModuleOp module, ModuleOp library, } } -FailureOr Torch::adjustFunctionArg( - OpBuilder &b, Location loc, Value operand, Type desiredType, - function_ref baseTransformation) { +FailureOr +Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, + Type desiredType, + function_ref + baseTransformation) { operand = baseTransformation(b, loc, operand, desiredType); // No need for adjustment if they already match. diff --git a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp index 6860fbb6eee8..fbbd6c48043b 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp @@ -90,7 +90,8 @@ class DecomposePromoteDtypesOp : public OpRewritePattern { PatternRewriter &rewriter) const override { SmallVector> ranks; SmallVector dtypes; - if (!matchPattern(op.getRanks(), m_TorchListOfOptionalConstantInts(ranks))) { + if (!matchPattern(op.getRanks(), + m_TorchListOfOptionalConstantInts(ranks))) { return rewriter.notifyMatchFailure( op, "Expected `ranks` to be a list of optional constant ints"); } diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 5bd254d72be1..2993a2697360 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -10,6 +10,7 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "mlir/IR/BuiltinDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" using namespace mlir; using namespace mlir::torch; @@ -100,6 +101,12 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { return torch_upstream::ScalarType::Byte; if (type.isSignedInteger(8)) return torch_upstream::ScalarType::Char; + if (type.isa()) + return torch_upstream::ScalarType::QUInt8; + if (type.isa()) + return torch_upstream::ScalarType::QInt8; + if (type.isa()) + return torch_upstream::ScalarType::QInt32; if (type.isa()) { mlir::Type complexElemType = type.cast().getElementType(); if (complexElemType.isF16()) @@ -111,7 +118,6 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { } llvm::report_fatal_error("unhandled type for getScalarTypeForType"); } - Type Torch::getTypeForTorchType( MLIRContext *context, Type type, mlir::IntegerType::SignednessSemantics signedness) { @@ -146,6 +152,12 @@ Torch::getTypeForScalarType(MLIRContext *context, return mlir::IntegerType::get(context, 8, mlir::IntegerType::Unsigned); case torch_upstream::ScalarType::Char: return mlir::IntegerType::get(context, 8, mlir::IntegerType::Signed); + case torch_upstream::ScalarType::QUInt8: + return QUInt8Type::get(context); + case torch_upstream::ScalarType::QInt8: + return QInt8Type::get(context); + case torch_upstream::ScalarType::QInt32: + return QInt32Type::get(context); case torch_upstream::ScalarType::ComplexHalf: return mlir::ComplexType::get(Float16Type::get(context)); case torch_upstream::ScalarType::ComplexFloat: @@ -247,7 +259,7 @@ bool Torch::isViewLikeOp(Operation *op) { AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp, AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp, PrimsSplitDimOp, AtenViewAsComplexOp, AtenViewAsRealOp, - AtenPixelShuffleOp>(op); + AtenPixelShuffleOp, AtenDiagonalOp>(op); } Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, @@ -369,9 +381,9 @@ FailureOr Torch::unsqueezeTensor(PatternRewriter &rewriter, // Checks whether the `shapeA` and `shapeB` are broadcast compatible or not. If // yes, then computes the final broadcast shape. void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc, - Value inputA, Value inputB, - SmallVector &resultShape, - SmallVector &resultShapeValue) { + Value inputA, Value inputB, + SmallVector &resultShape, + SmallVector &resultShapeValue) { SmallVector shapeA{ inputA.getType().cast().getSizes()}; SmallVector shapeB{ @@ -508,3 +520,75 @@ LogicalResult Torch::checkDefaultStrideHelper(Operation *op, return success(); } } + +// Helper to create a tensor filled with the given scalar. Scalar would be +// converted the to the element type of the given tensor type. +Value Torch::createInitTensor(PatternRewriter &rewriter, Location loc, + BaseTensorType resultType, Value scalar, + Value sizeList) { + assert(resultType.hasDtype() && "result must have dtype"); + Value noneVal = rewriter.create(loc); + Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype()); + return rewriter.create(loc, resultType, sizeList, scalar, dtype, + /*layout=*/noneVal, + /*device=*/noneVal, + /*memory_format=*/noneVal); +} + +// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar` +// would be converted to the element type of the given `inputType`. +Value Torch::createRank0Tensor(PatternRewriter &rewriter, Location loc, + BaseTensorType inputType, Value scalar) { + assert(inputType.hasDtype() && "input must have dtype"); + SmallVector sizes; + BaseTensorType rank0TensorTy = + inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype()) + .cast(); + Value dimList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())), + ValueRange{}); + return createInitTensor(rewriter, loc, rank0TensorTy, scalar, dimList); +} + +LogicalResult Torch::getTransposedType(BaseTensorType inType, int64_t dimA, + int64_t dimB, Type &transposedType) { + if (!inType.hasSizes()) + return failure(); + SmallVector shape(inType.getSizes()); + int64_t tmp = shape[dimA]; + shape[dimA] = shape[dimB]; + shape[dimB] = tmp; + transposedType = inType.getWithSizesAndDtype(llvm::ArrayRef(shape), + inType.getOptionalDtype()); + return success(); +} + +Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) { + if (inputType.isF16()) + return rewriter.getF32Type(); + if (inputType.isBF16()) + return rewriter.getF32Type(); + if (inputType.isa()) + return rewriter.getF32Type(); + if (inputType.isa()) + return rewriter.getF64Type(); + if (inputType.isFloat8E5M2()) + return rewriter.getF32Type(); + if (inputType.isFloat8E4M3FN()) + return rewriter.getF32Type(); + if (inputType.isFloat8E5M2FNUZ()) + return rewriter.getF32Type(); + if (inputType.isFloat8E4M3FNUZ()) + return rewriter.getF32Type(); + if (inputType.isSignedInteger(8)) + return rewriter.getI64Type(); + if (inputType.isUnsignedInteger(8)) + return rewriter.getI64Type(); + if (inputType.isSignedInteger(16)) + return rewriter.getI64Type(); + if (inputType.isSignedInteger(32)) + return rewriter.getI64Type(); + if (inputType.isSignedInteger(64)) + return rewriter.getI64Type(); + llvm::report_fatal_error("unhandled type for getDefaultAccType"); +} diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp index 4d38f4965df2..ac9a72586bef 100644 --- a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp +++ b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp @@ -54,14 +54,14 @@ void TorchConversionDialect::initialize() { addInterfaces(); } - //===----------------------------------------------------------------------===// // Constant materializer. //===----------------------------------------------------------------------===// Operation *TorchConversionDialect::materializeConstant(OpBuilder &builder, - Attribute value, Type type, - Location loc) { + Attribute value, + Type type, + Location loc) { if (auto integerType = type.dyn_cast()) return builder.create(loc, value.cast()); diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp index 8a5c218e4f3e..1cda55724ee3 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" using namespace mlir; using namespace mlir::torch; @@ -57,16 +57,16 @@ static void setupTorchBoolToI1Conversion(ConversionTarget &target, typeConverter.addConversion([](Torch::BoolType type) -> std::optional { return IntegerType::get(type.getContext(), 1); }); - typeConverter.addTargetMaterialization([](OpBuilder &builder, - IntegerType type, ValueRange inputs, - Location loc) -> std::optional { - // Other builtin integer types could be handled by other materializers. - if (!(type.getWidth() == 1 && type.isSignless())) - return std::nullopt; - assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addTargetMaterialization( + [](OpBuilder &builder, IntegerType type, ValueRange inputs, + Location loc) -> std::optional { + // Other builtin integer types could be handled by other materializers. + if (!(type.getWidth() == 1 && type.isSignless())) + return std::nullopt; + assert(inputs.size() == 1); + assert(inputs[0].getType().isa()); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::BoolType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -83,19 +83,19 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target, typeConverter.addConversion([](Torch::IntType type) -> std::optional { return IntegerType::get(type.getContext(), 64); }); - typeConverter.addTargetMaterialization([](OpBuilder &builder, - IntegerType type, ValueRange inputs, - Location loc) -> std::optional { - // Other builtin integer types could be handled by other materializers. - if (!(type.getWidth() == 64 && type.isSignless())) - return std::nullopt; - // Other input type to be converted to i64 are handled by other - // materializers. - if (!inputs[0].getType().isa()) - return std::nullopt; - assert(inputs.size() == 1); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addTargetMaterialization( + [](OpBuilder &builder, IntegerType type, ValueRange inputs, + Location loc) -> std::optional { + // Other builtin integer types could be handled by other materializers. + if (!(type.getWidth() == 64 && type.isSignless())) + return std::nullopt; + // Other input type to be converted to i64 are handled by other + // materializers. + if (!inputs[0].getType().isa()) + return std::nullopt; + assert(inputs.size() == 1); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::IntType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -112,13 +112,13 @@ static void setupTorchFloatToF64Conversion(ConversionTarget &target, typeConverter.addConversion([](Torch::FloatType type) -> std::optional { return Float64Type::get(type.getContext()); }); - typeConverter.addTargetMaterialization([](OpBuilder &builder, - Float64Type type, ValueRange inputs, - Location loc) -> std::optional { - assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addTargetMaterialization( + [](OpBuilder &builder, Float64Type type, ValueRange inputs, + Location loc) -> std::optional { + assert(inputs.size() == 1); + assert(inputs[0].getType().isa()); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::FloatType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -133,22 +133,23 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target, TypeConverter &typeConverter) { target.addLegalOp(); - typeConverter.addConversion([](Torch::GeneratorType type) -> std::optional { - return IntegerType::get(type.getContext(), 64); - }); - typeConverter.addTargetMaterialization([](OpBuilder &builder, - IntegerType type, ValueRange inputs, - Location loc) -> std::optional { - // Other builtin integer types could be handled by other materializers. - if (!(type.getWidth() == 64 && type.isSignless())) - return std::nullopt; - // Other input type to be converted to i64 are handled by other - // materializers. - if (!inputs[0].getType().isa()) - return std::nullopt; - assert(inputs.size() == 1); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addConversion( + [](Torch::GeneratorType type) -> std::optional { + return IntegerType::get(type.getContext(), 64); + }); + typeConverter.addTargetMaterialization( + [](OpBuilder &builder, IntegerType type, ValueRange inputs, + Location loc) -> std::optional { + // Other builtin integer types could be handled by other materializers. + if (!(type.getWidth() == 64 && type.isSignless())) + return std::nullopt; + // Other input type to be converted to i64 are handled by other + // materializers. + if (!inputs[0].getType().isa()) + return std::nullopt; + assert(inputs.size() == 1); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::GeneratorType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp index 5f3a2609be8c..896dd9577617 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp @@ -115,6 +115,21 @@ static void setupFinalization(ConversionTarget &target, setupFinalization(target, patterns, typeConverter); } +static void stripTorchAttrs(FunctionOpInterface func) { + bool modified = false; + SmallVector newAttrs; + for (auto attr : func->getDialectAttrs()) { + if (attr.getName().getValue().starts_with("torch.")) + modified = true; + else + newAttrs.push_back(attr); + } + if (modified) + func->setDialectAttrs(newAttrs); + + // Note: this could also strip "arg" and "result" attrs if they were used. +} + namespace { struct FinalizingBackendTypeConversionPass : public FinalizingBackendTypeConversionBase< @@ -151,11 +166,14 @@ struct FinalizingBackendTypeConversionPass if (failed(applyFullConversion(func, target, std::move(patterns)))) signalPassFailure(); + + // Drop attributes that are no longer used after conversion out of Torch. + stripTorchAttrs(func); } }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() { return std::make_unique(); } diff --git a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp index 175a3cd14804..7bcb67b17c61 100644 --- a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp +++ b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp @@ -18,8 +18,8 @@ #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" -#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" using namespace mlir; using namespace mlir::torch; @@ -65,7 +65,8 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { auto getConstantIntegerFromDefiningOp = [](Value operand, int &extractedInt) { - auto castOp = dyn_cast(operand.getDefiningOp()); + auto castOp = + dyn_cast(operand.getDefiningOp()); if (!castOp) { return failure(); } @@ -83,7 +84,8 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { return failure(); } int unpackedBitWidth; - if (failed(getConstantIntegerFromDefiningOp(unpackedTypeWidth, unpackedBitWidth))) { + if (failed(getConstantIntegerFromDefiningOp(unpackedTypeWidth, + unpackedBitWidth))) { return failure(); } if (unpackedBitWidth != @@ -103,32 +105,35 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { // expand lhs std::vector lhsExpandedShape = {lhsShape[0], lhsShape[1], lhsReductDimSize / gs, gs}; - RankedTensorType lhsExpandedType = RankedTensorType::get(lhsExpandedShape, elementType); + RankedTensorType lhsExpandedType = + RankedTensorType::get(lhsExpandedShape, elementType); SmallVector lhsReassociation = {{0}, {1}, {2, 3}}; Value lhsExpanded = rewriter.create( - loc, lhsExpandedType, lhs, lhsReassociation); + loc, lhsExpandedType, lhs, lhsReassociation); // expand rhs - std::vector rhsExpandedShape = {rhsShape[0], rhsReductDimSize/gs, gs}; - RankedTensorType rhsExpandedType = RankedTensorType::get(rhsExpandedShape, rhsElementType); + std::vector rhsExpandedShape = {rhsShape[0], rhsReductDimSize / gs, + gs}; + RankedTensorType rhsExpandedType = + RankedTensorType::get(rhsExpandedShape, rhsElementType); SmallVector rhsReassociation = {{0}, {1, 2}}; Value rhsExpanded = rewriter.create( - loc, rhsExpandedType, rhsQuant, rhsReassociation); + loc, rhsExpandedType, rhsQuant, rhsReassociation); Value cst0 = rewriter.create( - loc, FloatAttr::get(elementType, 0.0)); + loc, FloatAttr::get(elementType, 0.0)); - Value emptyDequant = rewriter.create( - loc, rhsExpandedShape, elementType); + Value emptyDequant = + rewriter.create(loc, rhsExpandedShape, elementType); SmallVector dynDims; for (int i = 0; i < lhsType.getRank(); i++) { if (lhsType.isDynamicDim(i)) { dynDims.push_back(rewriter.create(loc, lhs, i)); } } - Value empty = rewriter.create( - loc, resultShape, elementType, dynDims); - Value output = rewriter.create( - loc, cst0, empty).getResult(0); + Value empty = rewriter.create(loc, resultShape, + elementType, dynDims); + Value output = + rewriter.create(loc, cst0, empty).getResult(0); AffineExpr d0, d1, d2, d3, d4; bindDims(getContext(), d0, d1, d2, d3, d4); @@ -141,12 +146,12 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { SmallVector dqIndexingMaps = {map, map1, map1, map}; SmallVector matIndexingMaps = {map2, map3, map4}; - SmallVector dequantIteratorTypes(3, utils::IteratorType::parallel); + SmallVector dequantIteratorTypes( + 3, utils::IteratorType::parallel); SmallVector matmulIteratorTypes = { - utils::IteratorType::parallel, utils::IteratorType::parallel, - utils::IteratorType::parallel, utils::IteratorType::reduction, - utils::IteratorType::reduction - }; + utils::IteratorType::parallel, utils::IteratorType::parallel, + utils::IteratorType::parallel, utils::IteratorType::reduction, + utils::IteratorType::reduction}; Value rhsDequant = rewriter @@ -157,9 +162,12 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { /*iteratorTypes=*/dequantIteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value w = args[0], scale = args[1], zeroPoint = args[2]; - Value extw = b.create(loc, rewriter.getI32Type(), w); - Value fp_extw = b.create(loc, rewriter.getF16Type(), extw); - Value shifted = b.create(loc, fp_extw, zeroPoint); + Value extw = + b.create(loc, rewriter.getI32Type(), w); + Value fp_extw = b.create( + loc, rewriter.getF16Type(), extw); + Value shifted = + b.create(loc, fp_extw, zeroPoint); Value dqw = b.create(loc, shifted, scale); b.create(loc, dqw); }) @@ -168,8 +176,8 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { Value matmulDequant = rewriter .create( - loc, output.getType(), - ValueRange{lhsExpanded, rhsDequant}, output, + loc, output.getType(), ValueRange{lhsExpanded, rhsDequant}, + output, /*indexingMaps=*/matIndexingMaps, /*iteratorTypes=*/matmulIteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { @@ -188,7 +196,8 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { namespace { class ConvertCustomQuantOpPass - : public TorchConversion::ConvertCustomQuantOpBase { + : public TorchConversion::ConvertCustomQuantOpBase< + ConvertCustomQuantOpPass> { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); registry.insert(); @@ -213,14 +222,14 @@ class ConvertCustomQuantOpPass target.addIllegalOp(); patterns.add(typeConverter, context); - if (failed( - applyPartialConversion(getOperation(), target, std::move(patterns)))) + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) signalPassFailure(); } }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::TorchConversion::createConvertCustomQuantOpPass() { return std::make_unique(); } diff --git a/lib/Dialect/TorchConversion/Transforms/PassDetail.h b/lib/Dialect/TorchConversion/Transforms/PassDetail.h index 224ad8e2d89a..cb80ebd89a3c 100644 --- a/lib/Dialect/TorchConversion/Transforms/PassDetail.h +++ b/lib/Dialect/TorchConversion/Transforms/PassDetail.h @@ -10,7 +10,7 @@ #ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H #define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H -#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 09e99057e0b6..55bedc1192eb 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -9,18 +9,21 @@ #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "mlir/Conversion/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" +#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" +#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" -#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" +#include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h" #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" -#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #ifdef TORCH_MLIR_ENABLE_STABLEHLO #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #endif @@ -64,14 +67,20 @@ void mlir::torch::registerTorchConversionPasses() { void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( OpPassManager &pm) { + // We want to fuse quantized operations together before lowering to linalg. + pm.addNestedPass(Torch::createFuseQuantizedOpsPass()); + // Lower to linalg + guards which is the input to codegen backends. // We do this first as it tends to involve pattern-matching against constants, // (e.g. dimensions which must be constant in a ranked programming model) // and those constants get somewhat obscured by TorchToArith. pm.addNestedPass(createConvertTorchToTMTensorPass()); + pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createConvertTorchToLinalgPass()); + pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createConvertTorchToSCFPass()); pm.addNestedPass(createConvertTorchToArithPass()); + pm.addNestedPass(createConvertTorchToTensorPass()); pm.addPass(createConvertTorchConversionToMLProgramPass()); pm.addNestedPass(memref::createExpandOpsPass()); diff --git a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp index 25f325399f12..064c87f6e6a8 100644 --- a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp +++ b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp @@ -137,7 +137,7 @@ class UnpackQuantTensorPass }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::TorchConversion::createUnpackQuantTensorPass() { return std::make_unique(); } diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp index 93d7de8250a7..5ad3fa1c9f4f 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp @@ -33,7 +33,6 @@ using namespace mlir::torch; using namespace mlir::torch::TorchConversion; using namespace TMTensor; - namespace { class VerifyLinalgOnTensorsBackendContractPass : public VerifyLinalgOnTensorsBackendContractBase< @@ -96,7 +95,8 @@ class VerifyLinalgOnTensorsBackendContractPass // We avoid `module.emitError()` so that mlir-print-op-on-diagnostics // doesn't unnecessarily spew out the entire module. emitError(module.getLoc()) - << "Module does not conform to the linalg-on-tensors backend contract. " + << "Module does not conform to the linalg-on-tensors backend " + "contract. " "See dialect conversion legality information above."; return signalPassFailure(); } diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp index 888f29adedb2..c6085f419eac 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp @@ -45,7 +45,8 @@ class VerifyStablehloBackendContractPass ConversionTarget target(*context); // Structural operations. - target.addDynamicallyLegalOp(opHasLegalTypes); + target.addDynamicallyLegalOp( + opHasLegalTypes); // Shape operations. target.addDynamicallyLegalOp(opHasLegalTypes); diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index ace6c1a40e74..1205d6343e43 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/Dialect.h" @@ -29,6 +30,11 @@ #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "torch-mlir/RefBackend/Passes.h" +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +#include "stablehlo/conversions/linalg/transforms/Passes.h" +#include "stablehlo/transforms/Passes.h" +#endif + void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) { registry.insert(); registry.insert(); @@ -42,7 +48,8 @@ void mlir::torch::registerOptionalInputDialects( mlir::DialectRegistry ®istry) { registry.insert(); + scf::SCFDialect, tensor::TensorDialect, tosa::TosaDialect, + sparse_tensor::SparseTensorDialect>(); } void mlir::torch::registerAllPasses() { @@ -52,6 +59,11 @@ void mlir::torch::registerAllPasses() { mlir::torch::onnx_c::registerTorchOnnxToTorchPasses(); mlir::torch::TMTensor::registerPasses(); +#ifdef TORCH_MLIR_ENABLE_STABLEHLO + mlir::stablehlo::registerChloLegalizeToStablehloPass(); + mlir::stablehlo::registerStablehloLegalizeToLinalgPass(); +#endif + #ifdef TORCH_MLIR_ENABLE_REFBACKEND mlir::torch::RefBackend::registerRefBackendPasses(); #endif diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index 481bdf3426d8..4ada196e944c 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -20,10 +20,12 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Approximation.h" #include "mlir/Dialect/Math/Transforms/Passes.h" -#include "mlir/Dialect/MLProgram/IR/MLProgram.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" @@ -436,6 +438,29 @@ mlir::torch::RefBackend::createMungeMemrefCopyPass() { return std::make_unique(); } +namespace { +class GeneralizeTensorConcat + : public GeneralizeTensorConcatBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + tensor::populateDecomposeTensorConcatPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> +mlir::torch::RefBackend::createGeneralizeTensorConcatPass() { + return std::make_unique(); +} + namespace { class GeneralizeTensorPad : public GeneralizeTensorPadBase { diff --git a/projects/CMakeLists.txt b/projects/CMakeLists.txt index 4b54be65a79d..ea7e34593aba 100644 --- a/projects/CMakeLists.txt +++ b/projects/CMakeLists.txt @@ -1,7 +1,35 @@ include(AddMLIRPython) +if(TORCH_MLIR_ENABLE_ONNX_C_IMPORTER) + add_subdirectory(onnx_c_importer) +endif() + +################################################################################ +# PyTorch # Configure PyTorch if we have any features enabled which require it. +################################################################################ if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OR TORCH_MLIR_ENABLE_LTC) + + if (NOT TORCH_MLIR_USE_INSTALLED_PYTORCH) + # Source builds + message(STATUS "Building libtorch from source (features depend on it and NOT TORCH_MLIR_USE_INSTALLED_PYTORCH)") + set(ENV{TORCH_MLIR_SRC_PYTORCH_REPO} ${TORCH_MLIR_SRC_PYTORCH_REPO}) + set(ENV{TORCH_MLIR_SRC_PYTORCH_BRANCH} ${TORCH_MLIR_SRC_PYTORCH_BRANCH}) + set(ENV{TM_PYTORCH_INSTALL_WITHOUT_REBUILD} ${TM_PYTORCH_INSTALL_WITHOUT_REBUILD}) + set(ENV{MACOSX_DEPLOYMENT_TARGET} ${MACOSX_DEPLOYMENT_TARGET}) + set(ENV{CMAKE_OSX_ARCHITECTURES} ${CMAKE_OSX_ARCHITECTURES}) + set(ENV{CMAKE_C_COMPILER_LAUNCHER} ${CMAKE_C_COMPILER_LAUNCHER}) + set(ENV{CMAKE_CXX_COMPILER_LAUNCHER} ${CMAKE_CXX_COMPILER_LAUNCHER}) + execute_process( + COMMAND ${TORCH_MLIR_SOURCE_DIR}/build_tools/build_libtorch.sh + RESULT_VARIABLE _result + ) + if(_result) + message(FATAL_ERROR "Failed to run `build_libtorch.sh`") + endif() + set(TORCH_INSTALL_PREFIX "libtorch") + endif() + message(STATUS "Enabling PyTorch C++ dep (features depend on it)") include(TorchMLIRPyTorch) @@ -48,6 +76,6 @@ if(TORCH_MLIR_ENABLE_LTC) endif() # Include overall PT1 project. -if(TORCH_MLIR_ENABLE_PROJECT_PT1) +if(TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS) add_subdirectory(pt1) endif() diff --git a/projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.cpp b/projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.cpp index b144e946ba5e..47f7a974c8e8 100644 --- a/projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.cpp +++ b/projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.cpp @@ -9,6 +9,7 @@ #include "class_annotator.h" +#include #include using namespace torch_mlir; @@ -150,11 +151,26 @@ ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType *classType) { } static void fillArgAnnotations(MethodAnnotation &methodAnnotation, - std::vector argAnnotations, + const std::vector &argAnnotations, torch::jit::Function *function) { if (argAnnotations.size() != function->num_inputs()) { - throw std::invalid_argument("Arg annotations should have one entry per " - "function parameter (including self)."); + + std::ostringstream oss; + oss << "There must be one argument annotation per function parameter. " + << "Including 'self' the number of argument annotations is: " + << argAnnotations.size() + << ". The number of function parameters is: " << function->num_inputs() + << ". "; + const auto &args = function->getSchema().arguments(); + if (args.size() > 0) { + oss << "The function signature is ("; + oss << args[0]; + for (auto iter = args.begin() + 1; iter != args.end(); iter++) { + oss << ", " << *iter; + } + oss << ')' << '.'; + } + throw std::invalid_argument(oss.str()); } if (!methodAnnotation.argAnnotations.has_value()) { methodAnnotation.argAnnotations.emplace(function->num_inputs(), diff --git a/projects/ltc/csrc/base_lazy_backend/CMakeLists.txt b/projects/ltc/csrc/base_lazy_backend/CMakeLists.txt index eee3044f0fc9..2bbdbd233344 100644 --- a/projects/ltc/csrc/base_lazy_backend/CMakeLists.txt +++ b/projects/ltc/csrc/base_lazy_backend/CMakeLists.txt @@ -2,11 +2,21 @@ # Setup PyTorch/LTC #------------------------------------------------------------------------------- +torch_mlir_enable_werror() + set(LTC_GENERATED generated/LazyNativeFunctions.cpp generated/RegisterLazy.cpp generated/shape_inference.cpp ) + +# The auto generated files trigger some warnings we can't do anything about. +if(NOT MSVC) + set_source_files_properties(${LTC_GENERATED} + PROPERTIES COMPILE_FLAGS "-Wno-sign-compare -Wno-unused-function" + ) +endif() + set(LTC_BACKEND_DEPENDS mlir_lowering_context.cpp mlir_native_functions.cpp diff --git a/projects/ltc/csrc/base_lazy_backend/backend_impl.cpp b/projects/ltc/csrc/base_lazy_backend/backend_impl.cpp index bd4fe52b7b22..dc044879669e 100644 --- a/projects/ltc/csrc/base_lazy_backend/backend_impl.cpp +++ b/projects/ltc/csrc/base_lazy_backend/backend_impl.cpp @@ -31,18 +31,18 @@ TorchMlirBackendData::TorchMlirBackendData(BackendDevice device, Shape shape) PRINT_FUNCTION(); } TorchMlirBackendData::TorchMlirBackendData( - BackendDevice device, Shape shape, std::shared_ptr info) + BackendDevice device, Shape shape, std::shared_ptr info) : BackendData(device, shape), info_(info) { PRINT_FUNCTION(); } -TorchMlirBackendData::TorchMlirBackendData( - const at::Scalar& scalar, BackendDevice device) +TorchMlirBackendData::TorchMlirBackendData(const at::Scalar &scalar, + BackendDevice device) : BackendData(device, Shape(scalar.type(), {})), info_(std::make_shared(scalar)) { PRINT_FUNCTION(); } -TorchMlirBackendData::TorchMlirBackendData( - const at::Tensor& tensor, BackendDevice device, Shape shape) +TorchMlirBackendData::TorchMlirBackendData(const at::Tensor &tensor, + BackendDevice device, Shape shape) : BackendData(device, shape), info_(std::make_shared(tensor)) { PRINT_FUNCTION(); @@ -52,19 +52,18 @@ BackendData::Handle TorchMlirBackendData::GetHandle() { return reinterpret_cast(this); } -void TorchMlirBackendData::Assign(const BackendData& data) { - const TorchMlirBackendData* torch_mlir_data = - dynamic_cast(&data); - TORCH_CHECK( - torch_mlir_data, - "Invalid Backend Data Pointer. Expected TorchMlirBackendData."); +void TorchMlirBackendData::Assign(const BackendData &data) { + const TorchMlirBackendData *torch_mlir_data = + dynamic_cast(&data); + TORCH_CHECK(torch_mlir_data, + "Invalid Backend Data Pointer. Expected TorchMlirBackendData."); info_ = torch_mlir_data->info_; } bool TorchMlirBackendData::HasValue() const { return bool(info_); } -BackendData::Info* TorchMlirBackendData::mlir_info() const { +BackendData::Info *TorchMlirBackendData::mlir_info() const { return info_.get(); } @@ -77,8 +76,8 @@ void TorchMlirBackendImpl::PrepareToExit() const {} * IR Tracing * */ -const IrBuilder* TorchMlirBackendImpl::GetIrBuilder() const { - static const IrBuilder* builder = new TorchMlirIrBuilder(); +const IrBuilder *TorchMlirBackendImpl::GetIrBuilder() const { + static const IrBuilder *builder = new TorchMlirIrBuilder(); return builder; } @@ -87,28 +86,29 @@ const IrBuilder* TorchMlirBackendImpl::GetIrBuilder() const { * */ BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromTensor( - const at::Tensor& tensor, const Shape& shape, - const BackendDevice& device) const { + const at::Tensor &tensor, const Shape &shape, + const BackendDevice &device) const { PRINT_FUNCTION(); return std::make_shared(tensor, device, shape); } BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromScalar( - const at::Scalar& scalar, const BackendDevice& device) const { + const at::Scalar &scalar, const BackendDevice &device) const { PRINT_FUNCTION(); return std::make_shared(scalar, device); } -BackendDataPtr TorchMlirBackendImpl::CreateDataPlaceholder( - const BackendDevice& device, const Shape& shape) const { +BackendDataPtr +TorchMlirBackendImpl::CreateDataPlaceholder(const BackendDevice &device, + const Shape &shape) const { PRINT_FUNCTION(); return std::make_shared(device, shape); } BackendDataPtr -TorchMlirBackendImpl::GetComputationDataFromNode(const Node* node) const { +TorchMlirBackendImpl::GetComputationDataFromNode(const Node *node) const { PRINT_FUNCTION(); - const auto* device_data_node = dynamic_cast(node); + const auto *device_data_node = dynamic_cast(node); if (!device_data_node) { return nullptr; } @@ -120,14 +120,13 @@ at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData( c10::optional logical_scalar_type) const { PRINT_FUNCTION(); - TorchMlirBackendData* torch_mlir_data = - dynamic_cast(data.get()); - TORCH_CHECK( - torch_mlir_data, - "Invalid Backend Data Pointer. Expected TorchMlirBackendData."); + TorchMlirBackendData *torch_mlir_data = + dynamic_cast(data.get()); + TORCH_CHECK(torch_mlir_data, + "Invalid Backend Data Pointer. Expected TorchMlirBackendData."); - TorchMlirBackendData::Info* info = - dynamic_cast(torch_mlir_data->mlir_info()); + TorchMlirBackendData::Info *info = + dynamic_cast(torch_mlir_data->mlir_info()); TORCH_CHECK( info, "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info."); @@ -140,17 +139,19 @@ at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData( * */ std::unique_ptr TorchMlirBackendImpl::CreateLoweringContext( - const std::string& name, BackendDevice device, - c10::ArrayRef post_order, Util::EmissionMap emit_status) const { + const std::string &name, BackendDevice device, + c10::ArrayRef post_order, + Util::EmissionMap emit_status) const { PRINT_FUNCTION(); return std::make_unique( name, std::forward(device), - std::forward>(post_order), + std::forward>(post_order), std::forward(emit_status)); } -std::unique_ptr TorchMlirBackendImpl::CreateLoweringContext( - const std::string& name, BackendDevice device) const { +std::unique_ptr +TorchMlirBackendImpl::CreateLoweringContext(const std::string &name, + BackendDevice device) const { PRINT_FUNCTION(); return std::make_unique( name, std::forward(device)); @@ -175,9 +176,8 @@ at::DeviceType TorchMlirBackendImpl::EagerFallbackDeviceType() const { // Query all available backend devices std::vector TorchMlirBackendImpl::GetBackendDevices() const { PRINT_FUNCTION(); - return { - GetBackendDevice(c10::Device(c10::kLazy, 0)), - GetBackendDevice(c10::Device(c10::kCPU, 0))}; + return {GetBackendDevice(c10::Device(c10::kLazy, 0)), + GetBackendDevice(c10::Device(c10::kCPU, 0))}; } // Map a particular c10:: device to a concrete backend device diff --git a/projects/ltc/csrc/base_lazy_backend/backend_impl.h b/projects/ltc/csrc/base_lazy_backend/backend_impl.h index c77033593ba3..4029cab1ea90 100644 --- a/projects/ltc/csrc/base_lazy_backend/backend_impl.h +++ b/projects/ltc/csrc/base_lazy_backend/backend_impl.h @@ -41,27 +41,28 @@ class TORCH_API TorchMlirBackendData : public BackendData { name = ss.str(); ++i; } - Info(const Info& other) + Info(const Info &other) : tensor{other.tensor}, scalar{other.scalar}, requires_grad{other.requires_grad}, name{other.name} {} - Info(const at::Tensor& tensor) + Info(const at::Tensor &tensor) : tensor{tensor}, requires_grad{tensor.requires_grad()} {} - Info(const at::Scalar& scalar) : scalar{scalar}, requires_grad(false) {} + Info(const at::Scalar &scalar) : scalar{scalar}, requires_grad(false) {} }; TorchMlirBackendData(BackendDevice device, Shape shape); - TorchMlirBackendData(BackendDevice device, Shape shape, std::shared_ptr info); - TorchMlirBackendData(const at::Scalar& scalar, BackendDevice device); - TorchMlirBackendData( - const at::Tensor& tensor, BackendDevice device, Shape shape); + TorchMlirBackendData(BackendDevice device, Shape shape, + std::shared_ptr info); + TorchMlirBackendData(const at::Scalar &scalar, BackendDevice device); + TorchMlirBackendData(const at::Tensor &tensor, BackendDevice device, + Shape shape); virtual BackendData::Handle GetHandle() override; - virtual void Assign(const BackendData& data) override; + virtual void Assign(const BackendData &data) override; virtual bool HasValue() const override; - BackendData::Info* mlir_info() const; + BackendData::Info *mlir_info() const; protected: std::shared_ptr info_; @@ -80,7 +81,7 @@ class TORCH_API TorchMlirBackendImpl : public BackendImplInterface { * IR Tracing * */ - const IrBuilder* GetIrBuilder() const override; + const IrBuilder *GetIrBuilder() const override; /** * Configuration @@ -91,19 +92,22 @@ class TORCH_API TorchMlirBackendImpl : public BackendImplInterface { * Data Transfer * */ - virtual BackendDataPtr MakeComputationDataFromTensor( - const at::Tensor& tensor, const Shape& shape, - const BackendDevice& device) const override; + virtual BackendDataPtr + MakeComputationDataFromTensor(const at::Tensor &tensor, const Shape &shape, + const BackendDevice &device) const override; - virtual BackendDataPtr MakeComputationDataFromScalar( - const at::Scalar& scalar, const BackendDevice& device) const override; + virtual BackendDataPtr + MakeComputationDataFromScalar(const at::Scalar &scalar, + const BackendDevice &device) const override; - virtual BackendDataPtr CreateDataPlaceholder( - const BackendDevice& device, const Shape& shape) const override; + virtual BackendDataPtr + CreateDataPlaceholder(const BackendDevice &device, + const Shape &shape) const override; // Gets backend data if the node is a device data node. Otherwise returns // nullptr. - virtual BackendDataPtr GetComputationDataFromNode(const Node*) const override; + virtual BackendDataPtr + GetComputationDataFromNode(const Node *) const override; virtual at::Tensor MakeTensorFromComputationData( const BackendDataPtr data, @@ -113,13 +117,14 @@ class TORCH_API TorchMlirBackendImpl : public BackendImplInterface { * Lowering, Compilation, Execution * */ - virtual std::unique_ptr CreateLoweringContext( - const std::string& name, BackendDevice device, - c10::ArrayRef post_order, - Util::EmissionMap emit_status) const override; + virtual std::unique_ptr + CreateLoweringContext(const std::string &name, BackendDevice device, + c10::ArrayRef post_order, + Util::EmissionMap emit_status) const override; - virtual std::unique_ptr CreateLoweringContext( - const std::string& name, BackendDevice device) const override; + virtual std::unique_ptr + CreateLoweringContext(const std::string &name, + BackendDevice device) const override; // TODO(whc) need to keep this? // virtual std::vector GetCompilationDevices( diff --git a/projects/ltc/csrc/base_lazy_backend/dynamic_ir.cpp b/projects/ltc/csrc/base_lazy_backend/dynamic_ir.cpp index ca6d80f1f419..363bac959281 100644 --- a/projects/ltc/csrc/base_lazy_backend/dynamic_ir.cpp +++ b/projects/ltc/csrc/base_lazy_backend/dynamic_ir.cpp @@ -16,20 +16,18 @@ namespace torch { namespace lazy { DimensionNode::DimensionNode(OpKind op, OpList operands, hash_t hash_seed) - : TorchMlirNode( - op, operands, /*num_outputs=*/1, - /* hash_seed */ HashCombine(op.hash(), hash_seed)) {} + : TorchMlirNode(op, operands, /*num_outputs=*/1, + /* hash_seed */ HashCombine(op.hash(), hash_seed)) {} std::string DimensionNode::ToString() const { return "DimensionNode"; } SizeNode::SizeNode(Value input, size_t dim) - : DimensionNode( - OpKind{c10::Symbol::fromQualString("aten::size")}, {input}, - MHash(dim)), - dim_(dim){}; + : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::size")}, {input}, + MHash(dim)), + dim_(dim) {} int64_t SizeNode::getStaticValue() const { - return dynamic_cast(operand(0).node) + return dynamic_cast(operand(0).node) ->shape(0) .size(dim_); } @@ -37,35 +35,38 @@ int64_t SizeNode::getStaticValue() const { std::string SizeNode::ToString() const { return "SizeNode"; } SizeAdd::SizeAdd(Value a, Value b) - : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::add")}, {a, b}){}; + : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::add")}, {a, b}) {} int64_t SizeAdd::getStaticValue() const { - return dynamic_cast(operand(0).node)->getStaticValue() + - dynamic_cast(operand(1).node)->getStaticValue(); + return dynamic_cast(operand(0).node) + ->getStaticValue() + + dynamic_cast(operand(1).node)->getStaticValue(); } std::string SizeAdd::ToString() const { return "SizeAdd"; } SizeMul::SizeMul(Value a, Value b) - : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::mul")}, {a, b}){}; + : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::mul")}, {a, b}) {} int64_t SizeMul::getStaticValue() const { - return dynamic_cast(operand(0).node)->getStaticValue() * - dynamic_cast(operand(1).node)->getStaticValue(); + return dynamic_cast(operand(0).node) + ->getStaticValue() * + dynamic_cast(operand(1).node)->getStaticValue(); } std::string SizeMul::ToString() const { return "SizeMul"; } SizeDiv::SizeDiv(Value a, Value b) - : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::div")}, {a, b}){}; + : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::div")}, {a, b}) {} int64_t SizeDiv::getStaticValue() const { TORCH_CHECK( - dynamic_cast(operand(1).node)->getStaticValue() != + dynamic_cast(operand(1).node)->getStaticValue() != 0, "Can't divide a dimension by zero"); - return dynamic_cast(operand(0).node)->getStaticValue() / - dynamic_cast(operand(1).node)->getStaticValue(); + return dynamic_cast(operand(0).node) + ->getStaticValue() / + dynamic_cast(operand(1).node)->getStaticValue(); } std::string SizeDiv::ToString() const { return "SizeDiv"; } diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.cpp b/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.cpp index 7e6f40c5c2e9..a27889ad0895 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.cpp +++ b/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.cpp @@ -12,14 +12,14 @@ #include +#include "mlir-c/IR.h" +#include "mlir-c/Pass.h" +#include "torch-mlir-c/Registration.h" +#include "torch-mlir-c/Transforms.h" #include #include -#include #include -#include "torch-mlir-c/Registration.h" -#include "torch-mlir-c/Transforms.h" -#include "mlir-c/IR.h" -#include "mlir-c/Pass.h" +#include #include "backend_impl.h" #include "jit_ir_importer/function_importer.h" @@ -38,8 +38,8 @@ namespace lazy { // TorchMlir Lowering Context /////////////////////////////////////////////////////////////////////////////// -TorchMlirLoweringContext::TorchMlirLoweringContext( - const std::string& name, BackendDevice device) +TorchMlirLoweringContext::TorchMlirLoweringContext(const std::string &name, + BackendDevice device) : LoweringContext(name, std::forward(device)), graph_(std::make_shared()), function_( @@ -49,11 +49,12 @@ TorchMlirLoweringContext::TorchMlirLoweringContext( } TorchMlirLoweringContext::TorchMlirLoweringContext( - const std::string& name, BackendDevice device, - c10::ArrayRef post_order, Util::EmissionMap emit_status) + const std::string &name, BackendDevice device, + c10::ArrayRef post_order, + Util::EmissionMap emit_status) : LoweringContext( name, std::forward(device), - std::forward>(post_order), + std::forward>(post_order), std::forward(emit_status)), graph_(std::make_shared()), function_( @@ -66,9 +67,9 @@ TorchMlirLoweringContext::TorchMlirLoweringContext( } } -void TorchMlirLoweringContext::Lower(const Node* node) { - if (auto* torch_mlir_node = - dynamic_cast(node)) { +void TorchMlirLoweringContext::Lower(const Node *node) { + if (auto *torch_mlir_node = + dynamic_cast(node)) { TorchMlirOpVector ops = torch_mlir_node->Lower(function_, this); CHECK(!ops.empty()) << "Failed to lower: " << *node; TORCH_CHECK_EQ(node->num_outputs(), ops.size()); @@ -82,19 +83,19 @@ void TorchMlirLoweringContext::Lower(const Node* node) { } void TorchMlirLoweringContext::SetUpAlias( - const std::vector& output_index, int64_t param_number, - const std::vector& param_index, bool must_alias) { + const std::vector &output_index, int64_t param_number, + const std::vector ¶m_index, bool must_alias) { input_output_aliases_.push_back( {output_index, param_number, param_index, must_alias}); } bool TorchMlirLoweringContext::CheckResultShape( - const BackendDataPtr& parameter_data, size_t result_idx) { - TORCH_CHECK( - result_idx < root_tuple_.size(), "Tried getting result shape at index ", - result_idx, " which is out of bounds!"); + const BackendDataPtr ¶meter_data, size_t result_idx) { + TORCH_CHECK(result_idx < root_tuple_.size(), + "Tried getting result shape at index ", result_idx, + " which is out of bounds!"); - torch::jit::Value* output = root_tuple_[result_idx]; + torch::jit::Value *output = root_tuple_[result_idx]; if (c10::TensorTypePtr tensor_type = output->type()->cast()) { @@ -111,7 +112,7 @@ bool TorchMlirLoweringContext::CheckResultShape( return false; } -size_t TorchMlirLoweringContext::AddResult(const Output& output) { +size_t TorchMlirLoweringContext::AddResult(const Output &output) { PRINT_FUNCTION(); return AddResult(GetOutputOp(output)); @@ -120,9 +121,10 @@ size_t TorchMlirLoweringContext::AddResult(const Output& output) { // Associates the given output with the input parameter of the given index and // shape. Only used for the operator-by-operator execution, mostly for // debugging purposes. -void TorchMlirLoweringContext::AddParameter( - const torch::lazy::Output& output, size_t index, - const torch::lazy::Shape& shape, const std::string& name) { +void TorchMlirLoweringContext::AddParameter(const torch::lazy::Output &output, + size_t index, + const torch::lazy::Shape &shape, + const std::string &name) { UNIMPLEMENTED_FUNCTION_ERROR(); } @@ -136,7 +138,7 @@ ComputationPtr TorchMlirLoweringContext::Build() { torch::jit::RefineTupleTypes(graph_); // Insert return values into graph. - for (torch::jit::Value* output : root_tuple_) { + for (torch::jit::Value *output : root_tuple_) { graph_->block()->registerOutput(output); } @@ -152,7 +154,6 @@ ComputationPtr TorchMlirLoweringContext::Build() { /*getArgAttribute=*/[](int) -> MlirAttribute { return {nullptr}; }, /*importOptions=*/{/*assumeTensorsHaveValueSemantics=*/true}); - // Convert MlirOperation to MlirModule. MlirLocation loc = mlirLocationUnknownGet(mlir_context_); MlirModule module_op = mlirModuleCreateEmpty(loc); @@ -162,14 +163,10 @@ ComputationPtr TorchMlirLoweringContext::Build() { // Apply passes to verify generated MLIR. auto pass_manager = mlirPassManagerCreate(mlir_context_); mlirPassManagerAddOwnedPass( - pass_manager, - mlirCreateVerifyBackendContractNoDecompositions() - ); + pass_manager, mlirCreateVerifyBackendContractNoDecompositions()); - MlirLogicalResult result = mlirPassManagerRunOnOp( - pass_manager, - mlirModuleGetOperation(module_op) - ); + MlirLogicalResult result = + mlirPassManagerRunOnOp(pass_manager, mlirModuleGetOperation(module_op)); if (mlirLogicalResultIsFailure(result)) { throw std::runtime_error("MLIR verification has failed."); @@ -178,12 +175,14 @@ ComputationPtr TorchMlirLoweringContext::Build() { return CreateComputation(module_op); } -ComputationPtr TorchMlirLoweringContext::CreateComputation(MlirModule module_op) { - return std::make_shared( - module_op, mlir_context_, graph_, parameter_names_, input_output_aliases_); +ComputationPtr +TorchMlirLoweringContext::CreateComputation(MlirModule module_op) { + return std::make_shared(module_op, mlir_context_, + graph_, parameter_names_, + input_output_aliases_); } -torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) { +torch::jit::Value *TorchMlirLoweringContext::GetOutputOp(const Output &output) { PRINT_FUNCTION(); auto it = emitted_outputs_.find(output); @@ -195,15 +194,14 @@ torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) { // At this point the output better be present, otherwise there is an issue // with the lowering code. it = emitted_outputs_.find(output); - TORCH_CHECK( - it != emitted_outputs_.end(), - "No MLIR operation emitted for output: ", output.ToString()); + TORCH_CHECK(it != emitted_outputs_.end(), + "No MLIR operation emitted for output: ", output.ToString()); } return it->second; } -void TorchMlirLoweringContext::AssignOutputOp( - const Output& output, torch::jit::Value* op) { +void TorchMlirLoweringContext::AssignOutputOp(const Output &output, + torch::jit::Value *op) { PRINT_FUNCTION(); auto torch_mlir_node = @@ -211,48 +209,44 @@ void TorchMlirLoweringContext::AssignOutputOp( std::vector source_files, functions; std::vector line_numbers; - const auto& metadata = torch_mlir_node->metadata(); - const auto& frames = metadata.frame_info; + const auto &metadata = torch_mlir_node->metadata(); + const auto &frames = metadata.frame_info; if (!frames.empty()) { static std::vector g_roots = - string_split(sys_util::GetEnvString("LTC_IR_DEBUG_ROOT_PATH", ""), ":"); + string_split(sys_util::GetEnvString("LTC_IR_DEBUG_ROOT_PATH", ""), ":"); std::for_each(frames.rbegin(), frames.rend(), - [&](const torch::lazy::SourceLocation& location) { - functions.push_back(location.function); - line_numbers.push_back(location.line); - - std::string file_name = location.file; - for (const std::string& root : g_roots) { - if (startswith(file_name, root)) { - // location.file starts with root, strip it off - file_name = file_name.substr(root.size()); - break; - } - } - source_files.push_back(file_name); - }); + [&](const torch::lazy::SourceLocation &location) { + functions.push_back(location.function); + line_numbers.push_back(location.line); + + std::string file_name = location.file; + for (const std::string &root : g_roots) { + if (startswith(file_name, root)) { + // location.file starts with root, strip it off + file_name = file_name.substr(root.size()); + break; + } + } + source_files.push_back(file_name); + }); if (!source_files.empty()) { - op->node()->ss_( - c10::Symbol::attr("source_files"), source_files); - op->node()->ss_( - c10::Symbol::attr("functions"), functions); - op->node()->is_( - c10::Symbol::attr("line_numbers"), line_numbers); + op->node()->ss_(c10::Symbol::attr("source_files"), source_files); + op->node()->ss_(c10::Symbol::attr("functions"), functions); + op->node()->is_(c10::Symbol::attr("line_numbers"), line_numbers); } } auto scope = ::c10::Symbol::scope(metadata.scope); - op->node()->setScope( - c10::make_intrusive()->push(scope)); + op->node()->setScope(c10::make_intrusive()->push(scope)); emitted_outputs_[output] = std::move(op); } -torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) { +torch::jit::Value *TorchMlirLoweringContext::GetParameter(BackendDataPtr data) { PRINT_FUNCTION(); - if (!dynamic_cast(data.get())) { + if (!dynamic_cast(data.get())) { TORCH_CHECK( false, "Expected TorchMlirBackendData. Got some other BackendData type"); @@ -263,20 +257,21 @@ torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) { auto it = parameters_map_.find(handle); if (it == parameters_map_.end()) { - torch::jit::Value* param = + torch::jit::Value *param = graph_->addInput(c10::str("p", parameters_.size())); - auto* info = dynamic_cast(mlir_data->mlir_info()); + auto *info = + dynamic_cast(mlir_data->mlir_info()); TORCH_CHECK(info, "Expected TorchMlirBackendData::Info"); if (info->scalar.has_value()) { - auto& scalar = info->scalar.value(); + auto &scalar = info->scalar.value(); if (scalar.isFloatingPoint()) { param->setType(c10::FloatType::get()); } else if (scalar.isIntegral(true)) { param->setType(c10::IntType::get()); } else { - TORCH_CHECK( - false, "Unhandled scalar type: ", c10::toString(scalar.type())); + TORCH_CHECK(false, + "Unhandled scalar type: ", c10::toString(scalar.type())); } } else { // Save parameter shape information. @@ -305,7 +300,7 @@ std::shared_ptr TorchMlirLoweringContext::graph() const { return graph_; } -size_t TorchMlirLoweringContext::AddResult(torch::jit::Value* op) { +size_t TorchMlirLoweringContext::AddResult(torch::jit::Value *op) { PRINT_FUNCTION(); root_tuple_.push_back(std::move(op)); return root_tuple_.size() - 1; @@ -313,9 +308,9 @@ size_t TorchMlirLoweringContext::AddResult(torch::jit::Value* op) { // Sync vector of c10::Argument with type specified from parallel list of // jit::Value. There must be a 1:1 map between elements of args and values. -std::vector sync_argument_types( - const std::vector& args, - c10::ArrayRef values) { +std::vector +sync_argument_types(const std::vector &args, + c10::ArrayRef values) { TORCH_CHECK( args.size() == values.size(), "Expected 1:1 mapping between list of c10::Argument and jit::Value! Got ", @@ -362,7 +357,7 @@ void TorchMlirLoweringContext::RegisterMlirDialects() { TorchMlirComputation::TorchMlirComputation( MlirModule module_op, MlirContext mlir_context, - const std::shared_ptr& graph, + const std::shared_ptr &graph, std::unordered_map parameters_map, InputOutputAliases input_output_aliases) : module_op_(std::move(module_op)), mlir_context_(std::move(mlir_context)), @@ -377,26 +372,25 @@ TorchMlirComputation::TorchMlirComputation( } } -int TorchMlirComputation::parameters_size() const { - return num_parameters_; -} +int TorchMlirComputation::parameters_size() const { return num_parameters_; } -const std::vector& +const std::vector & TorchMlirComputation::parameter_shapes() const { throw std::runtime_error( "todo(whc) implement ts computation shapes or change interface"); return parameter_shapes_; } -const std::vector& TorchMlirComputation::parameter_names() const { +const std::vector &TorchMlirComputation::parameter_names() const { return parameter_names_; } -const std::unordered_map& TorchMlirComputation::parameters_map() const { +const std::unordered_map & +TorchMlirComputation::parameters_map() const { return parameters_map_; } -const torch::lazy::Shape& TorchMlirComputation::result_shape() const { +const torch::lazy::Shape &TorchMlirComputation::result_shape() const { throw std::runtime_error( "todo(whc) implement ts computation shapes or change interface"); return result_shape_; @@ -411,13 +405,9 @@ MlirOperation TorchMlirComputation::func_op() const { return mlirBlockGetFirstOperation(block); } -MlirModule TorchMlirComputation::module_op() const { - return module_op_; -} +MlirModule TorchMlirComputation::module_op() const { return module_op_; } -MlirContext TorchMlirComputation::mlir_context() const { - return mlir_context_; -} +MlirContext TorchMlirComputation::mlir_context() const { return mlir_context_; } const std::string TorchMlirComputation::debug_string() const { std::stringstream ss; @@ -430,7 +420,7 @@ const std::string TorchMlirComputation::debug_string() const { // Parameter names ss << "Parameter names:\n"; - for (auto& p : parameter_names_) { + for (auto &p : parameter_names_) { ss << " " << p << "\n"; } ss << "\n"; @@ -451,10 +441,10 @@ const std::string TorchMlirComputation::debug_string() const { const std::string TorchMlirComputation::to_string() const { // Since we use the C-MLIR API, we need to use a callback to print. - MlirStringCallback print_callback = [](MlirStringRef part, void* user_data) { + MlirStringCallback print_callback = [](MlirStringRef part, void *user_data) { // user_data is a void ptr to some data structure of our choice -- in this // case, the string stream where we'll be accumulating the strings. - std::stringstream* ss_ptr = static_cast(user_data); + std::stringstream *ss_ptr = static_cast(user_data); *ss_ptr << std::string(part.data, part.length); }; std::stringstream ss; @@ -462,7 +452,8 @@ const std::string TorchMlirComputation::to_string() const { // Setup flags for MLIR serialization. MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); mlirOpPrintingFlagsEnableDebugInfo(flags, FLAGS_torch_lazy_ir_debug, false); - mlirOperationPrintWithFlags(mlirModuleGetOperation(module_op_), flags, print_callback, &ss); + mlirOperationPrintWithFlags(mlirModuleGetOperation(module_op_), flags, + print_callback, &ss); return ss.str(); } diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.h b/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.h index f62a71ce7945..e69820535cb8 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.h +++ b/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.h @@ -39,35 +39,34 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext { }; using InputOutputAliases = std::vector; - TorchMlirLoweringContext( - const std::string& name, torch::lazy::BackendDevice device); - TorchMlirLoweringContext( - const std::string& name, torch::lazy::BackendDevice device, - c10::ArrayRef post_order, - torch::lazy::Util::EmissionMap emit_status); + TorchMlirLoweringContext(const std::string &name, + torch::lazy::BackendDevice device); + TorchMlirLoweringContext(const std::string &name, + torch::lazy::BackendDevice device, + c10::ArrayRef post_order, + torch::lazy::Util::EmissionMap emit_status); - void Lower(const Node* node); + void Lower(const Node *node); // Adds a new input/output alias. - void SetUpAlias( - const std::vector& output_index, int64_t param_number, - const std::vector& param_index, - bool must_alias = false) override; + void SetUpAlias(const std::vector &output_index, + int64_t param_number, const std::vector ¶m_index, + bool must_alias = false) override; // Check if parameter shape matches result at index. - bool CheckResultShape( - const BackendDataPtr& parameter_data, size_t result_idx) override; + bool CheckResultShape(const BackendDataPtr ¶meter_data, + size_t result_idx) override; // Adds the given output as a component of the result tuple and returns its // assigned position within the tuple. - size_t AddResult(const torch::lazy::Output& output) override; + size_t AddResult(const torch::lazy::Output &output) override; // Associates the given output with the input parameter of the given index and // shape. Only used for the operator-by-operator execution, mostly for // debugging purposes. - void AddParameter( - const torch::lazy::Output& output, size_t index, - const torch::lazy::Shape& shape, const std::string& name) override; + void AddParameter(const torch::lazy::Output &output, size_t index, + const torch::lazy::Shape &shape, + const std::string &name) override; // Build the computation capturing all the operations created with the // embedded builder (returned by the builder() API). @@ -78,27 +77,27 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext { // Retrieves the lowered operation for an output. If the requested output is // not available yet, the graph behind the output's Node is lowered, and the // corresponding TS operation returned. - torch::jit::Value* GetOutputOp(const Output& output); + torch::jit::Value *GetOutputOp(const Output &output); // Assigns the given TS operation to the specified output. As outputs are // lowered in a post-order fashion, later nodes should always find their // operands among the emitted outputs. - void AssignOutputOp(const Output& output, torch::jit::Value* op); + void AssignOutputOp(const Output &output, torch::jit::Value *op); // If a parameter associated with data has already been declared, it will be // returned. Otherwise a new one will be created, associated with the tensor // held in data. - torch::jit::Value* GetParameter(BackendDataPtr data); + torch::jit::Value *GetParameter(BackendDataPtr data); std::shared_ptr graph() const; protected: struct Parameter { - torch::jit::Value* param; + torch::jit::Value *param; size_t index = 0; }; - size_t AddResult(torch::jit::Value* op); + size_t AddResult(torch::jit::Value *op); // Creates a jit::Function from the current jit::Graph. Input and output // type information is patched to include shape. @@ -113,8 +112,8 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext { MlirContext mlir_context_; std::unordered_map parameters_map_; std::unordered_map parameter_names_; - std::vector root_tuple_; - OutputMap emitted_outputs_; + std::vector root_tuple_; + OutputMap emitted_outputs_; }; class TORCH_API TorchMlirComputation : public torch::lazy::Computation { @@ -122,21 +121,20 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation { using InputOutputAliases = TorchMlirLoweringContext::InputOutputAliases; using InputOutputAlias = TorchMlirLoweringContext::InputOutputAlias; - TorchMlirComputation( - MlirModule module_op, MlirContext mlir_context, - const std::shared_ptr& graph, - std::unordered_map parameters_map, - InputOutputAliases input_output_aliases); + TorchMlirComputation(MlirModule module_op, MlirContext mlir_context, + const std::shared_ptr &graph, + std::unordered_map parameters_map, + InputOutputAliases input_output_aliases); int parameters_size() const override; - const std::vector& parameter_shapes() const override; + const std::vector ¶meter_shapes() const override; - const std::vector& parameter_names() const override; + const std::vector ¶meter_names() const override; - const std::unordered_map& parameters_map() const; + const std::unordered_map ¶meters_map() const; - const torch::lazy::Shape& result_shape() const override; + const torch::lazy::Shape &result_shape() const override; std::shared_ptr graph() const; @@ -152,15 +150,14 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation { protected: size_t num_parameters_; - std::unordered_map parameters_map_; - std::vector parameter_names_; - std::vector parameter_shapes_; - Shape result_shape_; - MlirModule module_op_; MlirContext mlir_context_; std::shared_ptr graph_; InputOutputAliases input_output_aliases_; + std::unordered_map parameters_map_; + std::vector parameter_names_; + std::vector parameter_shapes_; + Shape result_shape_; }; } // namespace lazy diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_native_functions.cpp b/projects/ltc/csrc/base_lazy_backend/mlir_native_functions.cpp index 7d9fe056dc30..af680f224095 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_native_functions.cpp +++ b/projects/ltc/csrc/base_lazy_backend/mlir_native_functions.cpp @@ -10,8 +10,8 @@ // https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_native_functions.cpp //===----------------------------------------------------------------------===// -#include #include +#include #include #include #include @@ -33,16 +33,16 @@ #include "generated/LazyIr.h" #include "generated/LazyNativeFunctions.h" #include "generated/shape_inference.h" -#include "ops/to_copy.h" -#include "ops/unbind_int.h" -#include "ops/split.h" #include "ops/index.h" #include "ops/ivalue.h" +#include "ops/split.h" +#include "ops/to_copy.h" +#include "ops/unbind_int.h" #include "utils/exception.h" #include "utils/sys_utils.h" namespace { -at::Tensor to_meta(const at::Tensor& tensor) { +at::Tensor to_meta(const at::Tensor &tensor) { // undefined tensors can't be converted to the meta device, since they don't // have sizes/strides if (!tensor.defined()) @@ -60,26 +60,27 @@ at::Tensor to_meta(const at::Tensor& tensor) { return out; } -c10::optional to_meta(const c10::optional& tensor) { +c10::optional to_meta(const c10::optional &tensor) { if (tensor.has_value()) { return to_meta(*tensor); } return c10::nullopt; } -std::vector to_meta(at::ITensorListRef t_list) { +[[maybe_unused]] std::vector to_meta(at::ITensorListRef t_list) { std::vector outs; outs.reserve(t_list.size()); - for (const auto& tensor : t_list) { + for (const auto &tensor : t_list) { outs.push_back(to_meta(tensor)); } return outs; } -c10::List> to_meta(const c10::List>& t_list) { +c10::List> +to_meta(const c10::List> &t_list) { c10::List> outs; outs.reserve(t_list.size()); - for (const auto& tensor : t_list) { + for (const auto &tensor : t_list) { outs.push_back(to_meta(tensor)); } return outs; @@ -91,9 +92,9 @@ namespace lazy { namespace { -at::Tensor CreateLtcTensor( - const at::Tensor& tensor, - const c10::optional& device) { +[[maybe_unused]] at::Tensor +CreateLtcTensor(const at::Tensor &tensor, + const c10::optional &device) { if (tensor.defined() && device) { return torch::lazy::CreateAtenFromLtcTensor( torch::lazy::LazyTensor::Create(tensor, *device)); @@ -101,8 +102,8 @@ at::Tensor CreateLtcTensor( return tensor; } -c10::optional -GetLtcDevice(const c10::optional& device) { +[[maybe_unused]] c10::optional +GetLtcDevice(const c10::optional &device) { if (!device) { return c10::nullopt; } @@ -112,24 +113,23 @@ GetLtcDevice(const c10::optional& device) { return torch::lazy::atenDeviceToBackendDevice(*device); } -torch::lazy::Value MaybeExpand( - const torch::lazy::Value& input, const torch::lazy::Shape& target_shape) { +torch::lazy::Value MaybeExpand(const torch::lazy::Value &input, + const torch::lazy::Shape &target_shape) { if (input.shape().sizes() == target_shape.sizes()) { return input; } - return torch::lazy::MakeExpand( - input, target_shape.sizes().vec(), - /*is_scalar_expand=*/false); + return torch::lazy::MakeExpand(input, target_shape.sizes().vec(), + /*is_scalar_expand=*/false); } -void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) { +void copy_(torch::lazy::LazyTensorPtr &input, torch::lazy::LazyTensorPtr &src) { if (input->GetDevice() == src->GetDevice()) { torch::lazy::Value copy_value; if (input->dtype() == src->dtype()) { copy_value = src->GetIrValue(); } else { - copy_value = torch::lazy::MakeCast( - src->GetIrValue(), input->dtype(), src->dtype()); + copy_value = torch::lazy::MakeCast(src->GetIrValue(), input->dtype(), + src->dtype()); } input->SetIrValue(MaybeExpand(copy_value, input->shape())); } else { @@ -146,15 +146,17 @@ void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) { // clone is special in LT because we make it a no-op. // This should be safe to do, because every operator in the LT is functional. -at::Tensor LazyNativeFunctions::clone( - const at::Tensor& self, c10::optional memory_format) { +at::Tensor +LazyNativeFunctions::clone(const at::Tensor &self, + c10::optional memory_format) { auto self_lt = torch::lazy::TryGetLtcTensor(self); return torch::lazy::CreateAtenFromLtcTensor( self_lt->Create(self_lt->GetIrValue(), self_lt->GetDevice())); } -at::Tensor LazyNativeFunctions::_copy_from( - const at::Tensor& self, const at::Tensor& dst, bool non_blocking) { +at::Tensor LazyNativeFunctions::_copy_from(const at::Tensor &self, + const at::Tensor &dst, + bool non_blocking) { TORCH_LAZY_FN_COUNTER("lazy::"); auto dst_tensor = torch::lazy::TryGetLtcTensor(dst); auto self_tensor = torch::lazy::TryGetLtcTensor(self); @@ -199,16 +201,16 @@ at::Tensor LazyNativeFunctions::_copy_from( } } else { copy_(dst_tensor, self_tensor); - auto* impl = - dynamic_cast(dst.unsafeGetTensorImpl()); + auto *impl = + dynamic_cast(dst.unsafeGetTensorImpl()); impl->set_tensor(dst_tensor); } } return dst; } -at::Tensor LazyNativeFunctions::_copy_from_and_resize( - const at::Tensor& self, const at::Tensor& dst) { +at::Tensor LazyNativeFunctions::_copy_from_and_resize(const at::Tensor &self, + const at::Tensor &dst) { TORCH_LAZY_FN_COUNTER("lazy::"); auto dst_tensor = torch::lazy::TryGetLtcTensor(dst); auto self_tensor = torch::lazy::TryGetLtcTensor(self); @@ -223,8 +225,8 @@ at::Tensor LazyNativeFunctions::_copy_from_and_resize( dst.resize_as_(typed_tensor).copy_(typed_tensor); } else { // at this point we know dst is a lazy tensor - auto* dest_impl = - dynamic_cast(dst.unsafeGetTensorImpl()); + auto *dest_impl = + dynamic_cast(dst.unsafeGetTensorImpl()); dest_impl->tensor()->UpdateFromTensorOut(self_tensor); dest_impl->force_refresh_sizes(); } @@ -232,15 +234,16 @@ at::Tensor LazyNativeFunctions::_copy_from_and_resize( } at::Tensor LazyNativeFunctions::_to_copy( - const at::Tensor& self, c10::optional dtype, + const at::Tensor &self, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, bool non_blocking, c10::optional memory_format) { PRINT_FUNCTION(); auto options = self.options(); if (dtype) { - // I put each of these setters in a conditional instead of doing `self.options().dtype(dtype).layout(layout)... - // because calling .dtype(nullopt) on an options() that already has dtype appears to wipe it + // I put each of these setters in a conditional instead of doing + // `self.options().dtype(dtype).layout(layout)... because calling + // .dtype(nullopt) on an options() that already has dtype appears to wipe it options = options.dtype(dtype); } if (layout) { @@ -261,8 +264,9 @@ at::Tensor LazyNativeFunctions::_to_copy( if (!lazy_self && device && device->type() == c10::kLazy) { // Case 1: eager->lazy (we create a new lazy tensor) // See Note [Lazy Tensor Functionalization] - // Invariant: if the functionalization key is in the exclude set, then we're expected - // to return an ordinary tensor, which will be "lifted" into a functional wrapper later. + // Invariant: if the functionalization key is in the exclude set, then we're + // expected to return an ordinary tensor, which will be "lifted" into a + // functional wrapper later. bool functionalize_output = !c10::impl::tls_local_dispatch_key_set().excluded_.has( c10::DispatchKey::Functionalize); @@ -270,7 +274,8 @@ at::Tensor LazyNativeFunctions::_to_copy( self, options, *device, /*non_blocking=*/non_blocking, /*functionalize_output=*/functionalize_output); } else if (device && device->type() != c10::kLazy) { - // Case 2: lazy->eager (forces a graph break since we are materializing a tensor) + // Case 2: lazy->eager (forces a graph break since we are materializing a + // tensor) TORCH_INTERNAL_ASSERT(lazy_self); auto eager_tensor = lazy_self->ToTensor(/*detached=*/true); @@ -278,22 +283,24 @@ at::Tensor LazyNativeFunctions::_to_copy( auto moved_eager_tensor = eager_tensor.to(options, /*non_blocking=*/non_blocking, /*copy=*/true); return moved_eager_tensor; - } else if ( - device && device->type() == c10::kLazy && device->has_index() && - device->index() != self.device().index()) { + } else if (device && device->type() == c10::kLazy && device->has_index() && + device->index() != self.device().index()) { // Case 3: lazy:0 -> lazy:1 // TODO(whc) what do we actually want to do here? // option 1: materialize, move eager tensor, create new lazy tensor - // - this should be our default, as it is what would happen before we implemented _to_copy + // - this should be our default, as it is what would happen before we + // implemented _to_copy // - actually combines case 1 + case 2 // option 2: support multiple devices inside one lazy/TS executor (case 4) - // - but: we may have other assumptions that there is just one device per executor? so don't take this lightly + // - but: we may have other assumptions that there is just one device + // per executor? so don't take this lightly TORCH_INTERNAL_ASSERT(lazy_self); auto eager_tensor = lazy_self->ToTensor(/*detached=*/true); // we move the eager tensor to the 'eager' equivalent of our lazy device - // e.g. if our device is lazy:1, the backend maps that to cuda:1, which is what we use + // e.g. if our device is lazy:1, the backend maps that to cuda:1, which is + // what we use auto eager_device = c10::Device( torch::lazy::getBackend()->EagerFallbackDeviceType(), device->index()); options = options.device(eager_device); @@ -305,12 +312,14 @@ at::Tensor LazyNativeFunctions::_to_copy( return torch::lazy::CreateAtenFromLtcTensor(lazy_self); } else { - // Case 4: lazy->lazy (special case: keep the _to_copy INSIDE the lazy graph) - - // Note: captured _to_copy will be executed with real eager tensors, not lazy tensors. - // We DO NOT want to burn 'lazy:0' as the device into this captured IR, or we will try to - // convert an eager tensor back to a lazy one inside the torchscript executor - // lazy:0 -> lazy:1 is handled in case3, so we can safely drop the device argument + // Case 4: lazy->lazy (special case: keep the _to_copy INSIDE the lazy + // graph) + + // Note: captured _to_copy will be executed with real eager tensors, not + // lazy tensors. We DO NOT want to burn 'lazy:0' as the device into this + // captured IR, or we will try to convert an eager tensor back to a lazy one + // inside the torchscript executor lazy:0 -> lazy:1 is handled in case3, so + // we can safely drop the device argument device = c10::nullopt; auto shapes = torch::lazy::compute_shape__to_copy( @@ -325,259 +334,299 @@ at::Tensor LazyNativeFunctions::_to_copy( std::move(node), lazy_self->GetDevice())); return result; } -}; +} -at::Tensor LazyNativeFunctions::_unsafe_view( - const at::Tensor& self, at::IntArrayRef size) { +at::Tensor LazyNativeFunctions::_unsafe_view(const at::Tensor &self, + at::IntArrayRef size) { TORCH_LAZY_FN_COUNTER("lazy::"); - return LazyNativeFunctions::view_copy_symint(self, c10::fromIntArrayRefSlow(size)); + return LazyNativeFunctions::view_copy_symint(self, + c10::fromIntArrayRefSlow(size)); } -at::Tensor LazyNativeFunctions::t(const at::Tensor& self) { +at::Tensor LazyNativeFunctions::t(const at::Tensor &self) { TORCH_LAZY_FN_COUNTER("lazy::"); return at::functionalization::functionalize_aten_op::call(self); } -std::vector LazyNativeFunctions::unbind_copy(const at::Tensor & self, int64_t dim) { +std::vector LazyNativeFunctions::unbind_copy(const at::Tensor &self, + int64_t dim) { TORCH_LAZY_FN_COUNTER("lazy::"); auto common_device = torch::lazy::GetBackendDevice(self); TORCH_INTERNAL_ASSERT(common_device); - - LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); - torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue(), dim); + + LazyTensorPtr lazy_self = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); + torch::lazy::NodePtr node = + torch::lazy::ReuseNode(lazy_self->GetIrValue(), dim); if (!node) { auto self_meta = to_meta(self); - auto out_meta = at::compositeexplicitautogradnonfunctional::unbind_copy(self_meta, dim); - + auto out_meta = + at::compositeexplicitautogradnonfunctional::unbind_copy(self_meta, dim); + std::vector shapes; - for (const auto & shape : out_meta) { + for (const auto &shape : out_meta) { shapes.push_back( - torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec()) - ); + torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec())); } - if(torch::lazy::symbolicShapeEnabled()){ - std::vector inputs = { self, dim }; - const char* schema_str = "aten::unbind_copy.int(Tensor self, int dim=0) -> Tensor[]"; + if (torch::lazy::symbolicShapeEnabled()) { + std::vector inputs = {self, dim}; + const char *schema_str = + "aten::unbind_copy.int(Tensor self, int dim=0) -> Tensor[]"; applySymbolicShapesOnLT(schema_str, inputs, shapes); } - node = torch::lazy::MakeNode(lazy_self->GetIrValue(), dim, std::move(shapes)); + node = torch::lazy::MakeNode(lazy_self->GetIrValue(), dim, + std::move(shapes)); CacheNode(node); } - + std::vector result; for (size_t i = 0; i < node->num_outputs(); ++i) { result.push_back( - torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device) - ) - ); + torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create( + torch::lazy::Value(node, i), *common_device))); } return result; } -std::vector LazyNativeFunctions::split_with_sizes_copy_symint(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim) { +std::vector LazyNativeFunctions::split_with_sizes_copy_symint( + const at::Tensor &self, c10::SymIntArrayRef split_sizes, int64_t dim) { TORCH_LAZY_FN_COUNTER("lazy::"); auto common_device = torch::lazy::GetBackendDevice(self); TORCH_INTERNAL_ASSERT(common_device); - LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); - torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim); + LazyTensorPtr lazy_self = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); + torch::lazy::NodePtr node = torch::lazy::ReuseNode( + lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim); if (!node) { auto self_meta = to_meta(self); - auto out_meta = at::compositeexplicitautogradnonfunctional::split_with_sizes_copy_symint(self_meta, split_sizes, dim); + auto out_meta = at::compositeexplicitautogradnonfunctional:: + split_with_sizes_copy_symint(self_meta, split_sizes, dim); std::vector shapes; - for (const auto & shape : out_meta) { + for (const auto &shape : out_meta) { shapes.push_back( - torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec()) - ); + torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec())); } - if(torch::lazy::symbolicShapeEnabled()){ - std::vector inputs = { self, split_sizes, dim }; - const char* schema_str = "aten::split_with_sizes_copy(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]"; - applySymbolicShapesOnLT(schema_str, inputs, shapes); + if (torch::lazy::symbolicShapeEnabled()) { + std::vector inputs = {self, split_sizes, dim}; + const char *schema_str = "aten::split_with_sizes_copy(Tensor self, " + "SymInt[] split_sizes, int dim=0) -> Tensor[]"; + applySymbolicShapesOnLT(schema_str, inputs, shapes); } - node = torch::lazy::MakeNode(lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim, std::move(shapes)); + node = torch::lazy::MakeNode( + lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim, + std::move(shapes)); CacheNode(node); } std::vector result; for (size_t i = 0; i < node->num_outputs(); ++i) { result.push_back( - torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device) - ) - ); + torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create( + torch::lazy::Value(node, i), *common_device))); } return result; } -std::vector LazyNativeFunctions::split_copy_symint(const at::Tensor & self, c10::SymInt split_size, int64_t dim) { +std::vector +LazyNativeFunctions::split_copy_symint(const at::Tensor &self, + c10::SymInt split_size, int64_t dim) { TORCH_LAZY_FN_COUNTER("lazy::"); auto common_device = torch::lazy::GetBackendDevice(self); TORCH_INTERNAL_ASSERT(common_device); - LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); - torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue(), GetSymIntValue(split_size), dim); + LazyTensorPtr lazy_self = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); + torch::lazy::NodePtr node = torch::lazy::ReuseNode( + lazy_self->GetIrValue(), GetSymIntValue(split_size), dim); if (!node) { auto self_meta = to_meta(self); - auto out_meta = at::compositeexplicitautogradnonfunctional::split_copy_symint(self_meta, split_size, dim); + auto out_meta = + at::compositeexplicitautogradnonfunctional::split_copy_symint( + self_meta, split_size, dim); std::vector shapes; - for (const auto & shape : out_meta) { + for (const auto &shape : out_meta) { shapes.push_back( - torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec()) - ); + torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec())); } const size_t num_outputs = shapes.size(); - if(torch::lazy::symbolicShapeEnabled()){ - std::vector inputs = { self, split_size, dim }; - const char* schema_str = "aten::split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]"; - applySymbolicShapesOnLT(schema_str, inputs, shapes); + if (torch::lazy::symbolicShapeEnabled()) { + std::vector inputs = {self, split_size, dim}; + const char *schema_str = "aten::split_copy.Tensor(Tensor self, SymInt " + "split_size, int dim=0) -> Tensor[]"; + applySymbolicShapesOnLT(schema_str, inputs, shapes); } - node = torch::lazy::MakeNode(lazy_self->GetIrValue(), GetSymIntValue(split_size), dim, std::move(shapes), num_outputs); + node = torch::lazy::MakeNode( + lazy_self->GetIrValue(), GetSymIntValue(split_size), dim, + std::move(shapes), num_outputs); CacheNode(node); } std::vector result; for (size_t i = 0; i < node->num_outputs(); ++i) { result.push_back( - torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device) - ) - ); + torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create( + torch::lazy::Value(node, i), *common_device))); } return result; } -at::Tensor LazyNativeFunctions::index(const at::Tensor & self, const c10::List> & indices) { +at::Tensor LazyNativeFunctions::index( + const at::Tensor &self, + const c10::List> &indices) { TORCH_LAZY_FN_COUNTER("lazy::"); auto common_device = torch::lazy::GetBackendDevice(self); TORCH_INTERNAL_ASSERT(common_device); - LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); + LazyTensorPtr lazy_self = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); std::vector values; - for (const auto & it : indices) { + for (const auto &it : indices) { c10::optional tensor = it; - LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor())); - values.push_back(lazy_tensor ? lazy_tensor->GetIrValue() : torch::lazy::Value(MakeNode(c10::IValue()), 0)); + LazyTensorPtr lazy_tensor = + torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor())); + values.push_back( + lazy_tensor + ? lazy_tensor->GetIrValue() + : torch::lazy::Value(MakeNode(c10::IValue()), 0)); } auto list = MakeNode(values); - torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue(), list); + torch::lazy::NodePtr node = + torch::lazy::ReuseNode(lazy_self->GetIrValue(), list); if (!node) { auto self_meta = to_meta(self); auto indices_meta = to_meta(indices); auto out_meta = at::meta::index(self_meta, indices_meta); - std::vector shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; + std::vector shapes{ + torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; TORCH_INTERNAL_ASSERT(shapes.size() == 1); - if(torch::lazy::symbolicShapeEnabled()) { - std::vector inputs = { self, indices }; - const char* schema_str = "aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"; + if (torch::lazy::symbolicShapeEnabled()) { + std::vector inputs = {self, indices}; + const char *schema_str = + "aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"; applySymbolicShapesOnLT(schema_str, inputs, shapes); } - node = torch::lazy::MakeNode(lazy_self->GetIrValue(), list, std::move(shapes)); + node = torch::lazy::MakeNode(lazy_self->GetIrValue(), list, + std::move(shapes)); CacheNode(node); } auto result = torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::LazyTensor::Create(std::move(node), *common_device)); + torch::lazy::LazyTensor::Create(std::move(node), *common_device)); return result; } -at::Tensor LazyNativeFunctions::index_put(const at::Tensor & self, const c10::List> & indices, const at::Tensor & values, bool accumulate) { +at::Tensor LazyNativeFunctions::index_put( + const at::Tensor &self, const c10::List> &indices, + const at::Tensor &values, bool accumulate) { TORCH_LAZY_FN_COUNTER("lazy::"); auto common_device = torch::lazy::GetBackendDevice(self); TORCH_INTERNAL_ASSERT(common_device); - LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); - LazyTensorPtr lazy_valeus = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(values, *common_device); + LazyTensorPtr lazy_self = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); + LazyTensorPtr lazy_valeus = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(values, *common_device); std::vector indices_vector; - for (const auto & it : indices) { + for (const auto &it : indices) { c10::optional tensor = it; - LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor())); - indices_vector.push_back(lazy_tensor ? lazy_tensor->GetIrValue() : torch::lazy::Value(MakeNode(c10::IValue()), 0)); + LazyTensorPtr lazy_tensor = + torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor())); + indices_vector.push_back( + lazy_tensor + ? lazy_tensor->GetIrValue() + : torch::lazy::Value(MakeNode(c10::IValue()), 0)); } auto indices_list = MakeNode(indices_vector); - torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue(), indices_list, lazy_valeus->GetIrValue(), accumulate); + torch::lazy::NodePtr node = + torch::lazy::ReuseNode(lazy_self->GetIrValue(), indices_list, + lazy_valeus->GetIrValue(), accumulate); if (!node) { auto self_meta = to_meta(self); auto indices_meta = to_meta(indices); auto values_meta = to_meta(values); - auto out_meta = at::compositeexplicitautograd::index_put(self_meta, indices_meta, values_meta, accumulate); + auto out_meta = at::compositeexplicitautograd::index_put( + self_meta, indices_meta, values_meta, accumulate); - std::vector shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; + std::vector shapes{ + torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; TORCH_INTERNAL_ASSERT(shapes.size() == 1); - if(torch::lazy::symbolicShapeEnabled()) { - std::vector inputs = { self, indices, values }; - const char* schema_str = "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"; + if (torch::lazy::symbolicShapeEnabled()) { + std::vector inputs = {self, indices, values}; + const char *schema_str = + "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool " + "accumulate=False) -> Tensor"; applySymbolicShapesOnLT(schema_str, inputs, shapes); } - node = torch::lazy::MakeNode(lazy_self->GetIrValue(), indices_list, lazy_valeus->GetIrValue(), accumulate, std::move(shapes)); + node = torch::lazy::MakeNode( + lazy_self->GetIrValue(), indices_list, lazy_valeus->GetIrValue(), + accumulate, std::move(shapes)); CacheNode(node); } auto result = torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::LazyTensor::Create(std::move(node), *common_device)); + torch::lazy::LazyTensor::Create(std::move(node), *common_device)); return result; } // This is needed by the torch.tensor constructor. // LazyTensor always opts into functionalization. -// "lifting" a tensor for functionalization means wrapping it in a FunctionalTensorWrapper object. -at::Tensor LazyNativeFunctions::lift(const at::Tensor& tensor) { +// "lifting" a tensor for functionalization means wrapping it in a +// FunctionalTensorWrapper object. +at::Tensor LazyNativeFunctions::lift(const at::Tensor &tensor) { TORCH_INTERNAL_ASSERT( !at::functionalization::impl::isFunctionalTensor(tensor)); return at::functionalization::impl::to_functional_tensor(tensor); } -at::Tensor LazyNativeFunctions::lift_fresh(const at::Tensor& tensor) { +at::Tensor LazyNativeFunctions::lift_fresh(const at::Tensor &tensor) { TORCH_INTERNAL_ASSERT( !at::functionalization::impl::isFunctionalTensor(tensor)); return at::functionalization::impl::to_functional_tensor(tensor); } -// All of the below ops correspond to CompositeExplicitAutograd kernels from core -// that call into view operators internally. -// These are all composite ops that LTC can technically re-use / get for free, -// but we need to "functionalize" them to remove the view ops before we can use them. +// All of the below ops correspond to CompositeExplicitAutograd kernels from +// core that call into view operators internally. These are all composite ops +// that LTC can technically re-use / get for free, but we need to +// "functionalize" them to remove the view ops before we can use them. at::Tensor LazyNativeFunctions::block_diag(at::TensorList tensors) { return at::functionalization::functionalize_aten_op::call(tensors); } at::Tensor LazyNativeFunctions::new_empty_strided_symint( - const at::Tensor& self, - c10::SymIntArrayRef size, - c10::SymIntArrayRef stride, - c10::optional dtype, - c10::optional layout, - c10::optional device, + const at::Tensor &self, c10::SymIntArrayRef size, + c10::SymIntArrayRef stride, c10::optional dtype, + c10::optional layout, c10::optional device, c10::optional pin_memory) { if (!device || device->type() == c10::DeviceType::Lazy) { - return at::functionalization::functionalize_aten_op_symint< - ATEN_OP(new_empty_strided)>::call(self, size, stride, dtype, layout, - device, pin_memory); + return at::functionalization::functionalize_aten_op_symint::call(self, size, stride, dtype, layout, device, + pin_memory); } - // For cases when device != lazy, for example: lazy_tensor.new_empty_strided(..., "cpu") - // we need to avoid explicit functionalization. To do that we create regular cpu tensors. + // For cases when device != lazy, for example: + // lazy_tensor.new_empty_strided(..., "cpu") we need to avoid explicit + // functionalization. To do that we create regular cpu tensors. at::Tensor t = at::empty_symint( size, (dtype ? dtype : c10::optional(self.scalar_type())), (layout ? layout : c10::optional(self.layout())), device, @@ -585,65 +634,63 @@ at::Tensor LazyNativeFunctions::new_empty_strided_symint( return t.as_strided_symint(size, stride, /*storage_offset=*/0); } -at::Tensor LazyNativeFunctions::narrow_copy_symint( - const at::Tensor& self, - int64_t dim, - c10::SymInt start, - c10::SymInt length) { +at::Tensor LazyNativeFunctions::narrow_copy_symint(const at::Tensor &self, + int64_t dim, + c10::SymInt start, + c10::SymInt length) { return at::functionalization::functionalize_aten_op_symint::call(self, dim, start, length); } -at::Tensor LazyNativeFunctions::pixel_shuffle( - const at::Tensor& self, int64_t upscale_factor) { +at::Tensor LazyNativeFunctions::pixel_shuffle(const at::Tensor &self, + int64_t upscale_factor) { return at::functionalization::functionalize_aten_op::call(self, upscale_factor); } -at::Tensor LazyNativeFunctions::pixel_unshuffle( - const at::Tensor& self, int64_t downscale_factor) { +at::Tensor LazyNativeFunctions::pixel_unshuffle(const at::Tensor &self, + int64_t downscale_factor) { return at::functionalization::functionalize_aten_op::call(self, downscale_factor); } -at::Tensor LazyNativeFunctions::select_backward( - const at::Tensor& grad_output, at::IntArrayRef input_sizes, int64_t dim, - int64_t index) { +at::Tensor LazyNativeFunctions::select_backward(const at::Tensor &grad_output, + at::IntArrayRef input_sizes, + int64_t dim, int64_t index) { return at::functionalization::functionalize_aten_op::call(grad_output, input_sizes, dim, index); } at::Tensor LazyNativeFunctions::slice_backward_symint( - const at::Tensor& grad_output, - at::SymIntArrayRef input_sizes, - int64_t dim, - c10::SymInt start, - c10::SymInt end, - c10::SymInt step) { + const at::Tensor &grad_output, at::SymIntArrayRef input_sizes, int64_t dim, + c10::SymInt start, c10::SymInt end, c10::SymInt step) { return at::functionalization::functionalize_aten_op_symint::call(grad_output, input_sizes, dim, start, end, step); } -at::Tensor LazyNativeFunctions::diagonal_backward( - const at::Tensor& grad_output, at::IntArrayRef input_sizes, int64_t offset, - int64_t dim1, int64_t dim2) { +at::Tensor LazyNativeFunctions::diagonal_backward(const at::Tensor &grad_output, + at::IntArrayRef input_sizes, + int64_t offset, int64_t dim1, + int64_t dim2) { return at::functionalization::functionalize_aten_op::call(grad_output, input_sizes, offset, dim1, dim2); } at::Tensor LazyNativeFunctions::_trilinear( - const at::Tensor& i1, const at::Tensor& i2, const at::Tensor& i3, + const at::Tensor &i1, const at::Tensor &i2, const at::Tensor &i3, at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, at::IntArrayRef sumdim, int64_t unroll_dim) { - return at::functionalization::functionalize_aten_op:: - call(i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim); + return at::functionalization::functionalize_aten_op::call(i1, i2, i3, expand1, expand2, expand3, sumdim, + unroll_dim); } at::Tensor LazyNativeFunctions::linalg_pinv( - const at::Tensor& self, const c10::optional& atol, - const c10::optional& rtol, bool hermitian) { + const at::Tensor &self, const c10::optional &atol, + const c10::optional &rtol, bool hermitian) { return at::functionalization::functionalize_aten_op::call(self, atol, rtol, hermitian); } // functionalize_aten_op can't handle out= ops directly. -// Instead, we can call the composite kernel from core, and copy and mutations back to the inputs. -at::Tensor& LazyNativeFunctions::logsumexp_out( - const at::Tensor& self, at::IntArrayRef dim, bool keepdim, - at::Tensor& out) { +// Instead, we can call the composite kernel from core, and copy and mutations +// back to the inputs. +at::Tensor &LazyNativeFunctions::logsumexp_out(const at::Tensor &self, + at::IntArrayRef dim, + bool keepdim, at::Tensor &out) { auto self_wrapped = at::functionalization::impl::to_functional_tensor(self); auto out_wrapped = at::functionalization::impl::to_functional_tensor(out); // directly call the composite kernel from core. diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_node.cpp b/projects/ltc/csrc/base_lazy_backend/mlir_node.cpp index 39dc1ad0cd58..0f31fab2c1e0 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_node.cpp +++ b/projects/ltc/csrc/base_lazy_backend/mlir_node.cpp @@ -18,11 +18,10 @@ namespace lazy { namespace { -hash_t OperandHashes( - const OpList& operands, const c10::ArrayRef& shapes, - const hash_t& seed, bool bakeInSizes) { +hash_t OperandHashes(const OpList &operands, const c10::ArrayRef &shapes, + const hash_t &seed, bool bakeInSizes) { hash_t hash = seed; - for (auto& operand : operands) { + for (auto &operand : operands) { if (!operand) { hash = HashCombine(hash, static_cast(kNullOpt)); continue; @@ -30,7 +29,7 @@ hash_t OperandHashes( auto operand_hash = bakeInSizes ? operand.shapeHash() : operand.hash(); hash = HashCombine(hash, operand_hash); } - for (auto& shape : shapes) { + for (auto &shape : shapes) { hash = HashCombine(hash, shape.hash(bakeInSizes)); } return hash; @@ -38,53 +37,51 @@ hash_t OperandHashes( } // namespace - -// Adds a static hook that is run after every single TorchMlirNode is initialized -static std::vector> constructor_hooks; -void TorchMlirNode::addConstructorHook(std::function f) { +// Adds a static hook that is run after every single TorchMlirNode is +// initialized +static std::vector> constructor_hooks; +void TorchMlirNode::addConstructorHook(std::function f) { constructor_hooks.emplace_back(f); } -TorchMlirNode::TorchMlirNode( - OpKind op, OpList operands, std::vector&& shapes, size_t num_outputs, - hash_t hash_seed) +TorchMlirNode::TorchMlirNode(OpKind op, OpList operands, + std::vector &&shapes, size_t num_outputs, + hash_t hash_seed) : Node(op, operands, std::move(shapes), num_outputs) { hash_seed = HashCombine(op.hash(), hash_seed); shape_hash_ = OperandHashes(operands, this->shapes(), hash_seed, true); - dag_hash_ = - (enableDynamicShape() - ? OperandHashes(operands, this->shapes(), hash_seed, false) - : shape_hash_); + dag_hash_ = (enableDynamicShape() + ? OperandHashes(operands, this->shapes(), hash_seed, false) + : shape_hash_); - for (std::function& f : constructor_hooks) { + for (std::function &f : constructor_hooks) { f(this); } } -TorchMlirNode::TorchMlirNode( - OpKind op, OpList operands, const std::function& shape_fn, - size_t num_outputs, hash_t hash_seed) - : TorchMlirNode( - op, operands, std::vector{}, num_outputs, hash_seed) { +TorchMlirNode::TorchMlirNode(OpKind op, OpList operands, + const std::function &shape_fn, + size_t num_outputs, hash_t hash_seed) + : TorchMlirNode(op, operands, std::vector{}, num_outputs, + hash_seed) { addComputedShape(shape_fn); } -TorchMlirNode::TorchMlirNode( - OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed) - : TorchMlirNode( - op, operands, std::vector{}, num_outputs, hash_seed) {} +TorchMlirNode::TorchMlirNode(OpKind op, OpList operands, size_t num_outputs, + hash_t hash_seed) + : TorchMlirNode(op, operands, std::vector{}, num_outputs, + hash_seed) {} -TorchMlirNode::TorchMlirNode( - OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed) +TorchMlirNode::TorchMlirNode(OpKind op, Shape shape, size_t num_outputs, + hash_t hash_seed) : TorchMlirNode(op, {}, {std::move(shape)}, num_outputs, hash_seed) {} hash_t TorchMlirNode::hash() const { return dag_hash_; } hash_t TorchMlirNode::shapeHash() const { return shape_hash_; } - -TorchMlirNode* TorchMlirNode::mlir_node(int index) const { - return dynamic_cast(operands_.at(index).get()); +TorchMlirNode *TorchMlirNode::mlir_node(int index) const { + return dynamic_cast(operands_.at(index).get()); } /////////////////////////////////////////////////////////////////////////////// @@ -107,11 +104,12 @@ TorchMlirTensorList::TorchMlirTensorList(OpList values) /*num_outputs=*/1, /*hash_seed=*/kHashSeed) {} -torch::lazy::TorchMlirOpVector TorchMlirTensorList::Lower( - TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { - std::vector tensor_list; +torch::lazy::TorchMlirOpVector +TorchMlirTensorList::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { + std::vector tensor_list; CHECK(!operands().empty()); - for (const torch::lazy::Output& operand : operands()) { + for (const torch::lazy::Output &operand : operands()) { tensor_list.emplace_back(loctx->GetOutputOp(operand)); } auto graph = function->graph(); @@ -140,16 +138,17 @@ TorchMlirOptionalTensorList::TorchMlirOptionalTensorList(OpList values) /*num_outputs=*/1, /*hash_seed=*/kHashSeed) {} -torch::lazy::TorchMlirOpVector TorchMlirOptionalTensorList::Lower( - TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { - std::vector tensor_list; +torch::lazy::TorchMlirOpVector +TorchMlirOptionalTensorList::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { + std::vector tensor_list; CHECK(!operands().empty()); - for (const torch::lazy::Output& operand : operands()) { + for (const torch::lazy::Output &operand : operands()) { tensor_list.emplace_back(loctx->GetOutputOp(operand)); } auto graph = function->graph(); - auto listnode = - graph->insertNode(graph->createList(c10::OptionalType::create(c10::TensorType::get()), tensor_list)); + auto listnode = graph->insertNode(graph->createList( + c10::OptionalType::create(c10::TensorType::get()), tensor_list)); return {listnode->output()}; } diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_node.h b/projects/ltc/csrc/base_lazy_backend/mlir_node.h index a76ec0b05064..e5738a92176d 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_node.h +++ b/projects/ltc/csrc/base_lazy_backend/mlir_node.h @@ -27,23 +27,22 @@ namespace lazy { class TORCH_API TorchMlirNode : public torch::lazy::Node { public: - TorchMlirNode( - OpKind op, OpList operands, std::vector&& shapes, - size_t num_outputs, hash_t hash_seed = kHashSeed); + TorchMlirNode(OpKind op, OpList operands, std::vector &&shapes, + size_t num_outputs, hash_t hash_seed = kHashSeed); - TorchMlirNode( - OpKind op, OpList operands, const std::function& shape_fn, - size_t num_outputs, hash_t hash_seed = kHashSeed); + TorchMlirNode(OpKind op, OpList operands, + const std::function &shape_fn, size_t num_outputs, + hash_t hash_seed = kHashSeed); - TorchMlirNode( - OpKind op, OpList operands, size_t num_outputs, - hash_t hash_seed = kHashSeed); + TorchMlirNode(OpKind op, OpList operands, size_t num_outputs, + hash_t hash_seed = kHashSeed); - TorchMlirNode( - OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed = kHashSeed); + TorchMlirNode(OpKind op, Shape shape, size_t num_outputs, + hash_t hash_seed = kHashSeed); - // Adds a static hook that is run after every single TorchMlirNode is constructed - static void addConstructorHook(std::function); + // Adds a static hook that is run after every single TorchMlirNode is + // constructed + static void addConstructorHook(std::function); ~TorchMlirNode() override = default; @@ -51,10 +50,10 @@ class TORCH_API TorchMlirNode : public torch::lazy::Node { hash_t shapeHash() const override; - TorchMlirNode* mlir_node(int index) const; + TorchMlirNode *mlir_node(int index) const; - virtual TorchMlirOpVector - Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const; + virtual TorchMlirOpVector Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const; private: // The hash of the dag WITH size info. Used for shape caching @@ -86,22 +85,23 @@ struct TORCH_API TorchMlirTensorList : public TorchMlirNode { TorchMlirTensorList() = delete; TorchMlirTensorList(OpList values); - torch::lazy::TorchMlirOpVector Lower( - TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + torch::lazy::TorchMlirOpVector + Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const override; }; -// TorchMlirOptionalTensorList is similar to TorchMlirTensorList but it can also represent -// optional tensors, so the output type for this op is !torch.list>. +// TorchMlirOptionalTensorList is similar to TorchMlirTensorList but it can also +// represent optional tensors, so the output type for this op is +// !torch.list>. struct TORCH_API TorchMlirOptionalTensorList : public TorchMlirNode { static OpKind ClassOpKind(); TorchMlirOptionalTensorList() = delete; TorchMlirOptionalTensorList(OpList values); - torch::lazy::TorchMlirOpVector Lower( - TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + torch::lazy::TorchMlirOpVector + Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const override; }; } // namespace lazy diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.cpp b/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.cpp index a21bb93f0854..b52b724f0f16 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.cpp +++ b/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.cpp @@ -31,21 +31,23 @@ namespace torch { namespace lazy { -TorchMlirOpVector LowerTorchMlirBuiltin( - TorchMlirFunction function, c10::Symbol sym, - const std::vector tensor_types, - const std::vector& arguments, - const std::vector& kwarguments) { +TorchMlirOpVector +LowerTorchMlirBuiltin(TorchMlirFunction function, c10::Symbol sym, + const std::vector tensor_types, + const std::vector &arguments, + const std::vector &kwarguments) { // Workaround for ListType::isSubtypeOfExt behavior which leads to // the problems with JIT schema matching, so we need to keep // c10::ListType empty before magic_method->call function call. auto dummy_graph = torch::jit::Graph(); for (auto arg : arguments) { - torch::jit::Value* value = arg.value(dummy_graph); + torch::jit::Value *value = arg.value(dummy_graph); if (value->type()->kind() == c10::TypeKind::ListType) { - auto list_element_type = value->type()->cast()->getElementType(); + auto list_element_type = + value->type()->cast()->getElementType(); if (list_element_type->cast()) { - value->setType(c10::ListType::create(c10::OptionalType::create(c10::TensorType::get()))); + value->setType(c10::ListType::create( + c10::OptionalType::create(c10::TensorType::get()))); } else { value->setType(c10::ListType::create(c10::TensorType::get())); } @@ -56,25 +58,27 @@ TorchMlirOpVector LowerTorchMlirBuiltin( std::make_shared(sym, at::nullopt); auto magic_method = std::make_shared("", builtin); auto ret = magic_method->call({}, *function, arguments, kwarguments, 0); - auto sv = dynamic_cast(ret.get()); + auto sv = dynamic_cast(ret.get()); CHECK(sv); TorchMlirOpVector results; if (sv->getValue()->type()->kind() == c10::TypeKind::ListType) { - // Unpack dynamic multi-output operations like aten::split with Tensor[] output type. - // This is required to have consistent input types for multi-output node consumers. - torch::jit::Node * node = function->graph()->createListUnpack(sv->getValue(), tensor_types.size()); + // Unpack dynamic multi-output operations like aten::split with Tensor[] + // output type. This is required to have consistent input types for + // multi-output node consumers. + torch::jit::Node *node = function->graph()->createListUnpack( + sv->getValue(), tensor_types.size()); function->graph()->insertNode(node); - for (const auto & output : node->outputs()) { + for (const auto &output : node->outputs()) { results.push_back(output); } } else if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) { - // Op returns multiple values and the number of outputs is static and defined - // by the operation schema. + // Op returns multiple values and the number of outputs is static and + // defined by the operation schema. const auto tuple_call_result = sv->asTuple({}, *function); - for (const auto& tuple_component : tuple_call_result) { + for (const auto &tuple_component : tuple_call_result) { auto tuple_component_sv = - dynamic_cast(tuple_component.get()); + dynamic_cast(tuple_component.get()); results.push_back(tuple_component_sv->getValue()); } } else { @@ -84,7 +88,7 @@ TorchMlirOpVector LowerTorchMlirBuiltin( // Insert known tensor type information. unsigned tensor_type_idx = 0; - for (jit::Value* value : results) { + for (jit::Value *value : results) { if (value->type()->kind() == c10::TypeKind::TensorType) { TORCH_CHECK( tensor_type_idx < tensor_types.size(), function->graph()->toString(), @@ -97,23 +101,22 @@ TorchMlirOpVector LowerTorchMlirBuiltin( } // Ensure that we use up all the known tensor type information available. - TORCH_CHECK( - tensor_type_idx == tensor_types.size(), tensor_type_idx, - " known types were injected into jit::Value, but ", tensor_types.size(), - " were provided from lazy::Node!"); + TORCH_CHECK(tensor_type_idx == tensor_types.size(), tensor_type_idx, + " known types were injected into jit::Value, but ", + tensor_types.size(), " were provided from lazy::Node!"); return results; } -TorchMlirOpVector LowerTorchMlirBuiltin( - TorchMlirFunction function, c10::Symbol sym, - const c10::ArrayRef result_shapes, - const std::vector& arguments, - const std::vector& kwarguments) { +TorchMlirOpVector +LowerTorchMlirBuiltin(TorchMlirFunction function, c10::Symbol sym, + const c10::ArrayRef result_shapes, + const std::vector &arguments, + const std::vector &kwarguments) { std::vector tensor_types; // Generate types with fixed tensor shape information. - for (const Shape& shape : result_shapes) { + for (const Shape &shape : result_shapes) { tensor_types.push_back(torch::jit::TensorType::create( /*scalar_type=*/shape.scalar_type(), /*device=*/c10::nullopt, @@ -122,34 +125,34 @@ TorchMlirOpVector LowerTorchMlirBuiltin( /*requires_grad=*/c10::nullopt)); } - return LowerTorchMlirBuiltin( - function, sym, tensor_types, arguments, kwarguments); + return LowerTorchMlirBuiltin(function, sym, tensor_types, arguments, + kwarguments); } -TorchMlirOpVector LowerBuiltin( - const torch::lazy::Node* node, TorchMlirFunction function, - const std::vector& arguments, - const std::vector& kwarguments = {}) { - return LowerTorchMlirBuiltin( - function, node->op().op, node->shapes(), arguments, kwarguments); +TorchMlirOpVector +LowerBuiltin(const torch::lazy::Node *node, TorchMlirFunction function, + const std::vector &arguments, + const std::vector &kwarguments = {}) { + return LowerTorchMlirBuiltin(function, node->op().op, node->shapes(), + arguments, kwarguments); } -TorchMlirOpVector LowerBuiltin( - c10::Symbol sym, const c10::ArrayRef result_shapes, - TorchMlirFunction function, - const std::vector& arguments, - const std::vector& kwarguments = {}) { - return LowerTorchMlirBuiltin( - function, sym, result_shapes, arguments, kwarguments); +TorchMlirOpVector +LowerBuiltin(c10::Symbol sym, const c10::ArrayRef result_shapes, + TorchMlirFunction function, + const std::vector &arguments, + const std::vector &kwarguments = {}) { + return LowerTorchMlirBuiltin(function, sym, result_shapes, arguments, + kwarguments); } -TorchMlirOpVector LowerBuiltin( - c10::Symbol sym, const std::vector types, - TorchMlirFunction function, - const std::vector& arguments, - const std::vector& kwarguments = {}) { +TorchMlirOpVector +LowerBuiltin(c10::Symbol sym, const std::vector types, + TorchMlirFunction function, + const std::vector &arguments, + const std::vector &kwarguments = {}) { return LowerTorchMlirBuiltin(function, sym, types, arguments, kwarguments); } -c10::TensorType& cast_tensor_type(c10::TypePtr value_type) { +c10::TensorType &cast_tensor_type(c10::TypePtr value_type) { auto tensor_type = value_type->cast(); TORCH_CHECK(tensor_type, "Unable to cast Value type to TensorType!"); @@ -157,8 +160,8 @@ c10::TensorType& cast_tensor_type(c10::TypePtr value_type) { } c10::optional> -get_tensor_type_shape(c10::TensorType& tensor_type) { - auto& symbolic_shape = tensor_type.symbolic_sizes(); +get_tensor_type_shape(c10::TensorType &tensor_type) { + auto &symbolic_shape = tensor_type.symbolic_sizes(); if (!symbolic_shape.rank()) { return c10::nullopt; } @@ -175,21 +178,21 @@ get_tensor_type_shape(c10::TensorType& tensor_type) { } std::vector compute_shape_copy(c10::TypePtr value_type) { - c10::TensorType& tensor_type = cast_tensor_type(value_type); + c10::TensorType &tensor_type = cast_tensor_type(value_type); auto maybe_dims = get_tensor_type_shape(tensor_type); TORCH_CHECK(maybe_dims.has_value(), "Cannot copy unranked tensor!"); auto scalar_type = tensor_type.scalarType(); - TORCH_CHECK( - scalar_type.has_value(), "Unable to copy due to lack of scalar type!"); + TORCH_CHECK(scalar_type.has_value(), + "Unable to copy due to lack of scalar type!"); return {Shape(scalar_type.value(), maybe_dims.value())}; } -std::vector compute_shape_slice( - c10::TypePtr value_type, int64_t dim, int64_t start, int64_t end, - int64_t step) { - c10::TensorType& tensor_type = cast_tensor_type(value_type); +std::vector compute_shape_slice(c10::TypePtr value_type, + int64_t dim, int64_t start, + int64_t end, int64_t step) { + c10::TensorType &tensor_type = cast_tensor_type(value_type); auto maybe_dims = get_tensor_type_shape(tensor_type); TORCH_CHECK(maybe_dims.has_value(), "Cannot slice unranked tensor!"); @@ -217,13 +220,13 @@ std::vector compute_shape_slice( } auto scalar_type = tensor_type.scalarType(); - TORCH_CHECK( - scalar_type.has_value(), "Unable to slice due to lack of scalar type!"); + TORCH_CHECK(scalar_type.has_value(), + "Unable to slice due to lack of scalar type!"); return {Shape(scalar_type.value(), dims)}; } -torch::jit::Value* -GenerateClone(torch::jit::Value* val, TorchMlirFunction function) { +torch::jit::Value *GenerateClone(torch::jit::Value *val, + TorchMlirFunction function) { std::vector clone_arguments; clone_arguments.emplace_back(val); @@ -234,20 +237,19 @@ GenerateClone(torch::jit::Value* val, TorchMlirFunction function) { return cloned.front(); } -void GenerateCopy( - torch::jit::Value* destination, torch::jit::Value* source, - TorchMlirFunction function) { +void GenerateCopy(torch::jit::Value *destination, torch::jit::Value *source, + TorchMlirFunction function) { std::vector arguments; arguments.emplace_back(destination); arguments.emplace_back(source); - LowerBuiltin( - at::aten::copy_, c10::ArrayRef(compute_shape_copy(source->type())), - function, arguments); + LowerBuiltin(at::aten::copy_, + c10::ArrayRef(compute_shape_copy(source->type())), + function, arguments); } -torch::jit::Value* GenerateSlice( - torch::jit::Value* base, int64_t dim, int64_t start, int64_t end, - int64_t step, TorchMlirFunction function) { +torch::jit::Value *GenerateSlice(torch::jit::Value *base, int64_t dim, + int64_t start, int64_t end, int64_t step, + TorchMlirFunction function) { std::vector arguments; arguments.emplace_back(base); arguments.emplace_back(dim); @@ -255,11 +257,11 @@ torch::jit::Value* GenerateSlice( arguments.emplace_back(end); arguments.emplace_back(step); - TorchMlirOpVector selected = LowerBuiltin( - at::aten::slice, - c10::ArrayRef( - compute_shape_slice(base->type(), dim, start, end, step)), - function, arguments); + TorchMlirOpVector selected = + LowerBuiltin(at::aten::slice, + c10::ArrayRef(compute_shape_slice(base->type(), dim, + start, end, step)), + function, arguments); TORCH_CHECK_EQ(selected.size(), 1); return selected.front(); } @@ -267,10 +269,10 @@ torch::jit::Value* GenerateSlice( // Node Lowerings // Default Node Lowering -TorchMlirOpVector TorchMlirNode::Lower( - TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { +TorchMlirOpVector TorchMlirNode::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { std::vector arguments; - for (const torch::lazy::Output& output : operands()) { + for (const torch::lazy::Output &output : operands()) { arguments.emplace_back(loctx->GetOutputOp(output)); } return LowerBuiltin(this, function, arguments); @@ -280,19 +282,19 @@ TorchMlirOpVector TorchMlirNode::Lower( // Non-native nodes -TorchMlirOpVector -Cast::Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { +TorchMlirOpVector Cast::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { std::vector arguments; arguments.emplace_back(loctx->GetOutputOp(operand(0))); arguments.emplace_back(dtype); return LowerBuiltin(at::aten::to, shapes(), function, arguments); } -TorchMlirOpVector DeviceData::Lower( - TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { +TorchMlirOpVector DeviceData::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { auto infoptr = data_->info(); auto deviceDataInfoPtr = - (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr; + (torch::lazy::LazyGraphExecutor::DeviceDataInfo *)infoptr; if (GRAPH_DUMP_ENABLED) { LOG(ERROR) << "Lowering device data node, tensor id " << deviceDataInfoPtr->tensor_id << std::endl; @@ -300,8 +302,8 @@ TorchMlirOpVector DeviceData::Lower( return {loctx->GetParameter(data_)}; } -TorchMlirOpVector Scalar::Lower( - TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { +TorchMlirOpVector Scalar::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { auto options = at::TensorOptions() .device(torch::lazy::getBackend()->EagerFallbackDeviceType()) @@ -309,8 +311,8 @@ TorchMlirOpVector Scalar::Lower( return {loctx->graph()->insertConstant(at::scalar_tensor(value, options))}; } -TorchMlirOpVector Expand::Lower( - TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { +TorchMlirOpVector Expand::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { std::vector arguments; arguments.emplace_back(loctx->GetOutputOp(operand(0))); arguments.emplace_back(size); diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.h b/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.h index f9e028a5cc15..650bed045c25 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.h +++ b/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.h @@ -18,14 +18,14 @@ namespace torch { namespace lazy { -typedef std::vector TorchMlirOpVector; +typedef std::vector TorchMlirOpVector; typedef std::shared_ptr TorchMlirFunction; TORCH_API TorchMlirOpVector LowerTorchMlirBuiltin( TorchMlirFunction function, c10::Symbol sym, const c10::ArrayRef result_shapes, - const std::vector& arguments, - const std::vector& kwarguments = {}); + const std::vector &arguments, + const std::vector &kwarguments = {}); } // namespace lazy } // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/ops/device_data.cpp b/projects/ltc/csrc/base_lazy_backend/ops/device_data.cpp index b4271df6691e..c4255068fcb5 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/device_data.cpp +++ b/projects/ltc/csrc/base_lazy_backend/ops/device_data.cpp @@ -2,18 +2,16 @@ #include -#include "device_data.h" #include "../backend_impl.h" +#include "device_data.h" namespace torch { namespace lazy { DeviceData::DeviceData(std::shared_ptr data) - : TorchMlirNode( - ClassOpKind(), - data->shape(), - /*num_outputs=*/1, - /*hash_seed=*/static_cast(101)), + : TorchMlirNode(ClassOpKind(), data->shape(), + /*num_outputs=*/1, + /*hash_seed=*/static_cast(101)), data_(std::move(data)) { propagate_name(); } @@ -21,9 +19,11 @@ DeviceData::DeviceData(std::shared_ptr data) void DeviceData::propagate_name() { if (data_ && name_ != "") { // Add device data name to backend data - TorchMlirBackendData* mlir_data = dynamic_cast(data_.get()); + TorchMlirBackendData *mlir_data = + dynamic_cast(data_.get()); TORCH_CHECK(mlir_data); - auto* info = dynamic_cast(mlir_data->mlir_info()); + auto *info = + dynamic_cast(mlir_data->mlir_info()); TORCH_CHECK(info); info->name = name_; } @@ -34,7 +34,7 @@ void DeviceData::SetData(std::shared_ptr data) { propagate_name(); } -void DeviceData::SetName(const std::string& name) { +void DeviceData::SetName(const std::string &name) { name_ = name; propagate_name(); } @@ -43,12 +43,12 @@ std::string DeviceData::ToString() const { std::stringstream ss; ss << TorchMlirNode::ToString() << ", device=" << data_->device(); if (name_ != "") { - ss << ", name=" << name_; + ss << ", name=" << name_; } return ss.str(); } -const DeviceData* DeviceData::Cast(const Node* node) { +const DeviceData *DeviceData::Cast(const Node *node) { return NodeCast(node); } @@ -59,7 +59,7 @@ NodePtr DeviceData::Create(std::shared_ptr data) { // Ditching the old data_ is safe because tracing is done iteration // by iteration, and after we lauch the async device execution for the // previous iteration, data_ in DeviceData nodes are not needed anymore. - DeviceData* device_data = static_cast(node.get()); + DeviceData *device_data = static_cast(node.get()); device_data->SetData(data); return node; } diff --git a/projects/ltc/csrc/base_lazy_backend/ops/device_data.h b/projects/ltc/csrc/base_lazy_backend/ops/device_data.h index ad9d9d0eb94b..6f96d074962f 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/device_data.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/device_data.h @@ -6,15 +6,12 @@ #include #include - namespace torch { namespace lazy { class TORCH_API DeviceData : public TorchMlirNode { - public: - static OpKind ClassOpKind() { - return ltc_device_data; - } +public: + static OpKind ClassOpKind() { return ltc_device_data; } explicit DeviceData(std::shared_ptr data); @@ -27,22 +24,23 @@ class TORCH_API DeviceData : public TorchMlirNode { std::string ToString() const override; - const std::shared_ptr& data() const { return data_; } + const std::shared_ptr &data() const { return data_; } void SetData(std::shared_ptr data); - TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override; + TorchMlirOpVector Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const override; - static const DeviceData* Cast(const Node* node); + static const DeviceData *Cast(const Node *node); // To reuse IR nodes, use this method to create DeviceData nodes // instead of calling the constructor directly. static NodePtr Create(std::shared_ptr data); - const std::string& GetName() const { return name_; } - void SetName(const std::string& name); + const std::string &GetName() const { return name_; } + void SetName(const std::string &name); - private: +private: void propagate_name(); std::shared_ptr data_; diff --git a/projects/ltc/csrc/base_lazy_backend/ops/generic.cpp b/projects/ltc/csrc/base_lazy_backend/ops/generic.cpp index 1df8be231023..17e578946fb2 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/generic.cpp +++ b/projects/ltc/csrc/base_lazy_backend/ops/generic.cpp @@ -15,12 +15,8 @@ namespace torch { namespace lazy { -Generic::Generic( - OpKind op, - OpList operands, - Shape shape, - size_t num_outputs, - hash_t hash_seed) +Generic::Generic(OpKind op, OpList operands, Shape shape, size_t num_outputs, + hash_t hash_seed) : TorchMlirNode(op, operands, {std::move(shape)}, num_outputs, hash_seed), hash_seed_(hash_seed) {} diff --git a/projects/ltc/csrc/base_lazy_backend/ops/generic.h b/projects/ltc/csrc/base_lazy_backend/ops/generic.h index f294b1cfaed2..01794355a8b4 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/generic.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/generic.h @@ -23,15 +23,11 @@ namespace lazy { // captured by the LowerFn), but they should instead create a dedicated IR node. // Doing the former would limit IR introspection. class TORCH_API Generic : public TorchMlirNode { - public: - Generic( - OpKind op, - OpList operands, - Shape shape, - size_t num_outputs = 1, - hash_t hash_seed = static_cast(0x5a2d296e9)); +public: + Generic(OpKind op, OpList operands, Shape shape, size_t num_outputs = 1, + hash_t hash_seed = static_cast(0x5a2d296e9)); - private: +private: hash_t hash_seed_; }; diff --git a/projects/ltc/csrc/base_lazy_backend/ops/index.cpp b/projects/ltc/csrc/base_lazy_backend/ops/index.cpp index 34af3e590162..ffa2f06bbccf 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/index.cpp +++ b/projects/ltc/csrc/base_lazy_backend/ops/index.cpp @@ -12,9 +12,9 @@ namespace torch { namespace lazy { -IndexTensor::IndexTensor(const torch::lazy::Value& self, - const torch::lazy::Value& indices, - std::vector&& shapes) +IndexTensor::IndexTensor(const torch::lazy::Value &self, + const torch::lazy::Value &indices, + std::vector &&shapes) : torch::lazy::TorchMlirNode(IndexTensor::ClassOpKind(), OpList{self, indices}, std::move(shapes), /* num_outputs */ 1, torch::lazy::MHash()) {} @@ -25,13 +25,13 @@ std::string IndexTensor::ToString() const { return ss.str(); } -bool IndexTensor::CanBeReused(const torch::lazy::Value& self, - const torch::lazy::Value& indices) const { +bool IndexTensor::CanBeReused(const torch::lazy::Value &self, + const torch::lazy::Value &indices) const { return false; } TorchMlirOpVector IndexTensor::Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const { + TorchMlirLoweringContext *loctx) const { PRINT_FUNCTION(); std::vector arguments; std::vector kwarguments; @@ -49,10 +49,10 @@ TorchMlirOpVector IndexTensor::Lower(TorchMlirFunction function, return index_out; } -IndexPut::IndexPut(const torch::lazy::Value& self, - const torch::lazy::Value& indices, - const torch::lazy::Value& values, bool accumulate, - std::vector&& shapes) +IndexPut::IndexPut(const torch::lazy::Value &self, + const torch::lazy::Value &indices, + const torch::lazy::Value &values, bool accumulate, + std::vector &&shapes) : torch::lazy::TorchMlirNode( IndexPut::ClassOpKind(), OpList{self, indices, values}, std::move(shapes), @@ -66,15 +66,15 @@ std::string IndexPut::ToString() const { return ss.str(); } -bool IndexPut::CanBeReused(const torch::lazy::Value& self, - const torch::lazy::Value& indices, - const torch::lazy::Value& values, +bool IndexPut::CanBeReused(const torch::lazy::Value &self, + const torch::lazy::Value &indices, + const torch::lazy::Value &values, bool accumulate) const { return false; } TorchMlirOpVector IndexPut::Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const { + TorchMlirLoweringContext *loctx) const { PRINT_FUNCTION(); std::vector arguments; std::vector kwarguments; @@ -95,5 +95,5 @@ TorchMlirOpVector IndexPut::Lower(TorchMlirFunction function, return index_out; } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/ops/index.h b/projects/ltc/csrc/base_lazy_backend/ops/index.h index e97760fc37ad..6f63cbc686a6 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/index.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/index.h @@ -15,44 +15,44 @@ namespace torch { namespace lazy { class IndexTensor : public torch::lazy::TorchMlirNode { - public: +public: static torch::lazy::OpKind ClassOpKind() { return torch::lazy::OpKind(at::aten::index); } - IndexTensor(const torch::lazy::Value& self, const torch::lazy::Value& indices, - std::vector&& shapes); + IndexTensor(const torch::lazy::Value &self, const torch::lazy::Value &indices, + std::vector &&shapes); std::string ToString() const override; - bool CanBeReused(const torch::lazy::Value& self, - const torch::lazy::Value& indices) const; + bool CanBeReused(const torch::lazy::Value &self, + const torch::lazy::Value &indices) const; TorchMlirOpVector Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + TorchMlirLoweringContext *loctx) const override; }; class IndexPut : public torch::lazy::TorchMlirNode { - public: +public: static torch::lazy::OpKind ClassOpKind() { return torch::lazy::OpKind(at::aten::index_put); } - IndexPut(const torch::lazy::Value& self, const torch::lazy::Value& indices, - const torch::lazy::Value& values, bool accumulate, - std::vector&& shapes); + IndexPut(const torch::lazy::Value &self, const torch::lazy::Value &indices, + const torch::lazy::Value &values, bool accumulate, + std::vector &&shapes); std::string ToString() const override; - bool CanBeReused(const torch::lazy::Value& self, - const torch::lazy::Value& indices, - const torch::lazy::Value& values, bool accumulate) const; + bool CanBeReused(const torch::lazy::Value &self, + const torch::lazy::Value &indices, + const torch::lazy::Value &values, bool accumulate) const; TorchMlirOpVector Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + TorchMlirLoweringContext *loctx) const override; bool accumulate; }; -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/ops/ivalue.cpp b/projects/ltc/csrc/base_lazy_backend/ops/ivalue.cpp index 0653e4467313..e3db5ca37608 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/ivalue.cpp +++ b/projects/ltc/csrc/base_lazy_backend/ops/ivalue.cpp @@ -15,7 +15,7 @@ namespace torch { namespace lazy { -IValueConstant::IValueConstant(const c10::IValue& value) +IValueConstant::IValueConstant(const c10::IValue &value) : torch::lazy::TorchMlirNode(IValueConstant::ClassOpKind(), OpList{}, std::vector{}, /* num_outputs */ 1, torch::lazy::MHash()), @@ -28,9 +28,9 @@ std::string IValueConstant::ToString() const { } TorchMlirOpVector IValueConstant::Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const { + TorchMlirLoweringContext *loctx) const { return {loctx->graph()->insertConstant(value)}; } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/ops/ivalue.h b/projects/ltc/csrc/base_lazy_backend/ops/ivalue.h index 8f488ff47336..48fb95b73ddd 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/ivalue.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/ivalue.h @@ -18,20 +18,20 @@ namespace lazy { // parameter which is helpful in different usecases when we need custom // native ops lowering to torch-mlir IR nodes. class IValueConstant : public torch::lazy::TorchMlirNode { - public: +public: static torch::lazy::OpKind ClassOpKind() { return torch::lazy::OpKind(at::prim::Constant); } - IValueConstant(const c10::IValue& value); + IValueConstant(const c10::IValue &value); std::string ToString() const override; TorchMlirOpVector Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + TorchMlirLoweringContext *loctx) const override; c10::IValue value; }; -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/ops/split.cpp b/projects/ltc/csrc/base_lazy_backend/ops/split.cpp index d20d298dfdd0..91cbd2a52e3d 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/split.cpp +++ b/projects/ltc/csrc/base_lazy_backend/ops/split.cpp @@ -13,10 +13,10 @@ namespace torch { namespace lazy { SplitWithSizesCopy::SplitWithSizesCopy( - const torch::lazy::Value& self, const ::std::vector& split_sizes, - const int64_t& dim, std::vector&& shapes) + const torch::lazy::Value &self, const ::std::vector &split_sizes, + const int64_t &dim, std::vector &&shapes) : torch::lazy::TorchMlirNode(SplitWithSizesCopy::ClassOpKind(), - OpList{ self }, std::move(shapes), + OpList{self}, std::move(shapes), split_sizes.size() /* num_outputs */, torch::lazy::MHash(split_sizes, dim)), split_sizes(split_sizes), dim(dim) {} @@ -29,15 +29,15 @@ std::string SplitWithSizesCopy::ToString() const { return ss.str(); } -bool SplitWithSizesCopy::CanBeReused(const torch::lazy::Value& self, - const ::std::vector& split_sizes, - const int64_t& dim) const { +bool SplitWithSizesCopy::CanBeReused(const torch::lazy::Value &self, + const ::std::vector &split_sizes, + const int64_t &dim) const { return false; } TorchMlirOpVector SplitWithSizesCopy::Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const { + TorchMlirLoweringContext *loctx) const { PRINT_FUNCTION(); std::vector arguments; std::vector kwarguments; @@ -55,13 +55,13 @@ SplitWithSizesCopy::Lower(TorchMlirFunction function, return split_with_sizes_copy_out; } -SplitCopyTensor::SplitCopyTensor(const torch::lazy::Value& self, - const torch::lazy::Value& split_size, - const int64_t& dim, - std::vector&& shapes, +SplitCopyTensor::SplitCopyTensor(const torch::lazy::Value &self, + const torch::lazy::Value &split_size, + const int64_t &dim, + std::vector &&shapes, const size_t num_outputs) : torch::lazy::TorchMlirNode(SplitCopyTensor::ClassOpKind(), - OpList{ self, split_size }, std::move(shapes), + OpList{self, split_size}, std::move(shapes), num_outputs, torch::lazy::MHash(dim)), dim(dim) {} @@ -72,15 +72,15 @@ std::string SplitCopyTensor::ToString() const { return ss.str(); } -bool SplitCopyTensor::CanBeReused(const torch::lazy::Value& self, - const torch::lazy::Value& split_size, - const int64_t& dim) const { +bool SplitCopyTensor::CanBeReused(const torch::lazy::Value &self, + const torch::lazy::Value &split_size, + const int64_t &dim) const { return false; } TorchMlirOpVector SplitCopyTensor::Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const { + TorchMlirLoweringContext *loctx) const { PRINT_FUNCTION(); std::vector arguments; std::vector kwarguments; diff --git a/projects/ltc/csrc/base_lazy_backend/ops/split.h b/projects/ltc/csrc/base_lazy_backend/ops/split.h index 8593d5628c2e..116ddd64ab2b 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/split.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/split.h @@ -20,19 +20,19 @@ class SplitWithSizesCopy : public torch::lazy::TorchMlirNode { return torch::lazy::OpKind(at::aten::split_with_sizes_copy); } - SplitWithSizesCopy(const torch::lazy::Value& self, - const ::std::vector& split_sizes, - const int64_t& dim, - std::vector&& shapes); + SplitWithSizesCopy(const torch::lazy::Value &self, + const ::std::vector &split_sizes, + const int64_t &dim, + std::vector &&shapes); std::string ToString() const override; - bool CanBeReused(const torch::lazy::Value& self, - const ::std::vector& split_sizes, - const int64_t& dim) const; + bool CanBeReused(const torch::lazy::Value &self, + const ::std::vector &split_sizes, + const int64_t &dim) const; TorchMlirOpVector Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + TorchMlirLoweringContext *loctx) const override; std::vector split_sizes; int64_t dim; @@ -44,19 +44,19 @@ class SplitCopyTensor : public torch::lazy::TorchMlirNode { return torch::lazy::OpKind(at::aten::split_copy); } - SplitCopyTensor(const torch::lazy::Value& self, - const torch::lazy::Value& split_size, const int64_t& dim, - std::vector&& shapes, + SplitCopyTensor(const torch::lazy::Value &self, + const torch::lazy::Value &split_size, const int64_t &dim, + std::vector &&shapes, const size_t num_outputs = 1); std::string ToString() const override; - bool CanBeReused(const torch::lazy::Value& self, - const torch::lazy::Value& split_size, - const int64_t& dim) const; + bool CanBeReused(const torch::lazy::Value &self, + const torch::lazy::Value &split_size, + const int64_t &dim) const; TorchMlirOpVector Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + TorchMlirLoweringContext *loctx) const override; int64_t dim; }; diff --git a/projects/ltc/csrc/base_lazy_backend/ops/to_copy.h b/projects/ltc/csrc/base_lazy_backend/ops/to_copy.h index c6b75baaf8f3..402355031474 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/to_copy.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/to_copy.h @@ -17,61 +17,65 @@ namespace torch { namespace lazy { - -// This IR was copied from code-generated output, but the entire _to_copy operator -// cannot be trivially code genereated since it is only desirable to capture IR for -// certain permutaions of _to_copy (e.g. dtype), and for the others it is difficult to even invoke -// the aten/eager fallback necessitating directly implementing the right to(device) behavior +// This IR was copied from code-generated output, but the entire _to_copy +// operator cannot be trivially code genereated since it is only desirable to +// capture IR for certain permutaions of _to_copy (e.g. dtype), and for the +// others it is difficult to even invoke the aten/eager fallback necessitating +// directly implementing the right to(device) behavior class ToCopy : public torch::lazy::TorchMlirNode { - public: - ToCopy(const torch::lazy::Value& self, const c10::optional& dtype, const c10::optional& layout, const c10::optional& device, const c10::optional& pin_memory, const bool& non_blocking, const c10::optional& memory_format, std::vector&& shapes) - : torch::lazy::TorchMlirNode(torch::lazy::OpKind(at::aten::_to_copy), - {self}, std::move(shapes), - /* num_outputs */ 1, - torch::lazy::MHash(dtype, layout, device, pin_memory, non_blocking, memory_format)), +public: + ToCopy(const torch::lazy::Value &self, + const c10::optional &dtype, + const c10::optional &layout, + const c10::optional &device, + const c10::optional &pin_memory, const bool &non_blocking, + const c10::optional &memory_format, + std::vector &&shapes) + : torch::lazy::TorchMlirNode( + torch::lazy::OpKind(at::aten::_to_copy), {self}, std::move(shapes), + /* num_outputs */ 1, + torch::lazy::MHash(dtype, layout, device, pin_memory, non_blocking, + memory_format)), - dtype(dtype), - layout(layout), - device(device), - pin_memory(pin_memory), - non_blocking(non_blocking), - memory_format(memory_format) {} + dtype(dtype), layout(layout), device(device), pin_memory(pin_memory), + non_blocking(non_blocking), memory_format(memory_format) {} std::string ToString() const override { std::stringstream ss; ss << torch::lazy::TorchMlirNode::ToString(); if (dtype.has_value()) { - ss << ", dtype=" << dtype.value(); + ss << ", dtype=" << dtype.value(); } else { - ss << ", dtype=null"; + ss << ", dtype=null"; } if (layout.has_value()) { - ss << ", layout=" << layout.value(); + ss << ", layout=" << layout.value(); } else { - ss << ", layout=null"; + ss << ", layout=null"; } if (device.has_value()) { - ss << ", device=" << device.value(); + ss << ", device=" << device.value(); } else { - ss << ", device=null"; + ss << ", device=null"; } if (pin_memory.has_value()) { - ss << ", pin_memory=" << pin_memory.value(); + ss << ", pin_memory=" << pin_memory.value(); } else { - ss << ", pin_memory=null"; + ss << ", pin_memory=null"; } ss << ", non_blocking=" << non_blocking; if (memory_format.has_value()) { - ss << ", memory_format=" << memory_format.value(); + ss << ", memory_format=" << memory_format.value(); } else { - ss << ", memory_format=null"; + ss << ", memory_format=null"; } return ss.str(); } - torch::lazy::TorchMlirOpVector Lower(TorchMlirFunction function, - torch::lazy::TorchMlirLoweringContext* loctx) const override { - std::vector arguments; + torch::lazy::TorchMlirOpVector + Lower(TorchMlirFunction function, + torch::lazy::TorchMlirLoweringContext *loctx) const override { + std::vector arguments; std::vector kwarguments; arguments.reserve(1); kwarguments.reserve(6); @@ -83,11 +87,12 @@ class ToCopy : public torch::lazy::TorchMlirNode { kwarguments.emplace_back("pin_memory", pin_memory); kwarguments.emplace_back("non_blocking", non_blocking); kwarguments.emplace_back("memory_format", memory_format); - torch::lazy::TorchMlirOpVector _to_copy_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments); + torch::lazy::TorchMlirOpVector _to_copy_out = + torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), + arguments, kwarguments); TORCH_CHECK_EQ(_to_copy_out.size(), 1); return _to_copy_out; - } c10::optional dtype; @@ -97,5 +102,5 @@ class ToCopy : public torch::lazy::TorchMlirNode { bool non_blocking; c10::optional memory_format; }; -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.cpp b/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.cpp index a5526366cd2b..c43c84d24d5e 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.cpp +++ b/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.cpp @@ -12,9 +12,9 @@ namespace torch { namespace lazy { -UnbindCopyInt::UnbindCopyInt(const torch::lazy::Value& self, const int64_t& dim, - std::vector&& shapes) - : torch::lazy::TorchMlirNode(UnbindCopyInt::ClassOpKind(), OpList{ self }, +UnbindCopyInt::UnbindCopyInt(const torch::lazy::Value &self, const int64_t &dim, + std::vector &&shapes) + : torch::lazy::TorchMlirNode(UnbindCopyInt::ClassOpKind(), OpList{self}, std::move(shapes), self.shape().size(dim), /* num_outputs */ torch::lazy::MHash(dim)), @@ -27,13 +27,13 @@ std::string UnbindCopyInt::ToString() const { return ss.str(); } -bool UnbindCopyInt::CanBeReused(const torch::lazy::Value& self, - const int64_t& dim) const { +bool UnbindCopyInt::CanBeReused(const torch::lazy::Value &self, + const int64_t &dim) const { return false; } TorchMlirOpVector UnbindCopyInt::Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const { + TorchMlirLoweringContext *loctx) const { PRINT_FUNCTION(); std::vector arguments; std::vector kwarguments; diff --git a/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.h b/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.h index 766752c16517..9d6d83842b10 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.h @@ -20,15 +20,15 @@ class UnbindCopyInt : public torch::lazy::TorchMlirNode { return torch::lazy::OpKind(at::aten::unbind_copy); } - UnbindCopyInt(const torch::lazy::Value& self, const int64_t& dim, - std::vector&& shapes); + UnbindCopyInt(const torch::lazy::Value &self, const int64_t &dim, + std::vector &&shapes); std::string ToString() const override; - bool CanBeReused(const torch::lazy::Value& self, const int64_t& dim) const; + bool CanBeReused(const torch::lazy::Value &self, const int64_t &dim) const; TorchMlirOpVector Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + TorchMlirLoweringContext *loctx) const override; int64_t dim; }; diff --git a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp index d5458f9c4ea6..8e3b2c0702d3 100644 --- a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp +++ b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp @@ -21,28 +21,82 @@ namespace lazy { // TODO(henrytu): Upstream these shape inference functions to PyTorch in the // future. -std::vector compute_shape_add(const at::Tensor& self, - const at::Scalar& other, - const at::Scalar& alpha) { +std::vector compute_shape_add(const at::Tensor &self, + const at::Scalar &other, + const at::Scalar &alpha) { return {Shape(self.scalar_type(), self.sizes().vec())}; } - -std::vector compute_shape_sub(const at::Tensor& self, - const at::Scalar& other, - const at::Scalar& alpha) { +std::vector compute_shape_sub(const at::Tensor &self, + const at::Scalar &other, + const at::Scalar &alpha) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_div(const at::Tensor& self, - const at::Scalar& other) { +std::vector compute_shape_div(const at::Tensor &self, + const at::Scalar &other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } +std::vector +compute_shape__make_per_channel_quantized_tensor(const at::Tensor &self, + const at::Tensor &scale, + const at::Tensor &zero_point, + int64_t axis) { + if (self.scalar_type() == at::kChar) + return {Shape(at::kQInt8, self.sizes().vec())}; + if (self.scalar_type() == at::kByte) + return {Shape(at::kQUInt8, self.sizes().vec())}; + if (self.scalar_type() == at::kInt) + return {Shape(at::kQInt32, self.sizes().vec())}; + assert(false); +} + +std::vector compute_shape__make_per_tensor_quantized_tensor( + const at::Tensor &self, double scale, int64_t zero_point) { + if (self.scalar_type() == at::kChar) + return {Shape(at::kQInt8, self.sizes().vec())}; + if (self.scalar_type() == at::kByte) + return {Shape(at::kQUInt8, self.sizes().vec())}; + if (self.scalar_type() == at::kInt) + return {Shape(at::kQInt32, self.sizes().vec())}; + assert(false); +} + +std::vector compute_shape_int_repr(const at::Tensor &self) { + if (self.scalar_type() == at::kQInt8) + return {Shape(at::kChar, self.sizes().vec())}; + if (self.scalar_type() == at::kQUInt8) + return {Shape(at::kByte, self.sizes().vec())}; + if (self.scalar_type() == at::kQInt32) + return {Shape(at::kInt, self.sizes().vec())}; + assert(false); +} + +std::vector +compute_shape_dequantize(const at::Tensor &self) { + return {Shape(at::kFloat, self.sizes().vec())}; +} + +std::vector +compute_shape_quantize_per_tensor(const at::Tensor &self, double scale, + int64_t zero_point, at::ScalarType dtype) { + return {Shape(dtype, self.sizes().vec())}; +} + +std::vector compute_shape_isinf(const at::Tensor &self) { + return {Shape(at::kBool, self.sizes().vec())}; +} + +std::vector compute_shape_quantize_per_channel( + const at::Tensor &self, const at::Tensor &scales, + const at::Tensor &zero_points, int64_t axis, at::ScalarType dtype) { + return {Shape(dtype, self.sizes().vec())}; +} + std::vector compute_shape_max_pool3d_with_indices( - const at::Tensor& self, at::IntArrayRef kernel_size, - at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, - bool ceil_mode) { + const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { auto in_sizes = self.sizes().vec(); std::vector dhw(3, 0); std::vector paddings = padding.vec(); @@ -50,18 +104,19 @@ std::vector compute_shape_max_pool3d_with_indices( std::vector dilations = dilation.vec(); std::vector strides = stride.vec(); TORCH_CHECK(in_sizes.size() == 5, "max_pool3d requires 5D inputs, but got ", - in_sizes); - TORCH_CHECK(kernel_size.size() == 3 && - stride.size() == 3 && - padding.size() == 3 && - dilation.size() == 3, "max_pool3d requires 3D operands, but got ", - kernel_size, stride, padding, dilation); + in_sizes); + TORCH_CHECK(kernel_size.size() == 3 && stride.size() == 3 && + padding.size() == 3 && dilation.size() == 3, + "max_pool3d requires 3D operands, but got ", kernel_size, stride, + padding, dilation); int64_t batch = in_sizes[0]; int64_t channel = in_sizes[1]; // NCDHW // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool3d.html - for (auto i = 0UL; i<3; ++i) { - double out_size = (in_sizes[2+i] + 2 * paddings[i] - dilations[i] * - (ksizes[i] - 1) - 1) / (double)strides[i] + 1; + for (auto i = 0UL; i < 3; ++i) { + double out_size = (in_sizes[2 + i] + 2 * paddings[i] - + dilations[i] * (ksizes[i] - 1) - 1) / + (double)strides[i] + + 1; if (ceil_mode) dhw[i] = (int64_t)std::ceil(out_size); else @@ -73,46 +128,54 @@ std::vector compute_shape_max_pool3d_with_indices( } std::vector compute_shape_max_pool3d_with_indices_backward( - const at::Tensor & grad_output, const at::Tensor & self, + const at::Tensor &grad_output, const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, - const at::Tensor & indices) { + const at::Tensor &indices) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_mse_loss_backward( - const at::Tensor& grad_output, const at::Tensor& self, - const at::Tensor& target, int64_t reduction) { +std::vector +compute_shape_mse_loss_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &target, int64_t reduction) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_mul(const at::Tensor& self, - const at::Scalar& other) { +std::vector compute_shape_mul(const at::Tensor &self, + const at::Scalar &other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_var( - const at::Tensor& self, at::OptionalIntArrayRef dim, - const c10::optional & correction, bool keepdim) { +std::vector +compute_shape_var(const at::Tensor &self, at::OptionalIntArrayRef dim, + const c10::optional &correction, bool keepdim) { // Result of variance is scalar tensor. return {Shape(self.scalar_type(), {})}; } -std::vector compute_shape_hardtanh( - const at::Tensor& self, const at::Scalar& min_val, - const at::Scalar& max_val) { +std::vector +compute_shape_nan_to_num(const at::Tensor &self, c10::optional nan, + c10::optional posinf, + c10::optional neginf) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector +compute_shape_hardtanh(const at::Tensor &self, const at::Scalar &min_val, + const at::Scalar &max_val) { return {Shape(self.scalar_type(), self.sizes().vec())}; } std::vector compute_shape_hardtanh_backward( - const at::Tensor& grad_output, const at::Tensor& self, - const at::Scalar& min_val, const at::Scalar& max_val) { + const at::Tensor &grad_output, const at::Tensor &self, + const at::Scalar &min_val, const at::Scalar &max_val) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_where(const at::Tensor& condition, - const at::Tensor& self, - const at::Tensor& other) { +std::vector compute_shape_where(const at::Tensor &condition, + const at::Tensor &self, + const at::Tensor &other) { // There are cases like - // torch.aten.where.self %42, %arg17, %37 : !torch.vtensor<[15,10],i1>, // !torch.vtensor<[],f32>, !torch.vtensor<[15,10],f32>. @@ -139,32 +202,32 @@ std::vector compute_shape_where(const at::Tensor& condition, return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; } -std::vector compute_shape_bucketize( - const at::Tensor& self, const at::Tensor& boundaries, bool out_int32, - bool right) { +std::vector +compute_shape_bucketize(const at::Tensor &self, const at::Tensor &boundaries, + bool out_int32, bool right) { auto dtype = out_int32 ? at::kInt : at::kLong; return {Shape(dtype, self.sizes().vec())}; } -std::vector compute_shape_copy(const at::Tensor& self, - const at::Tensor& src, +std::vector compute_shape_copy(const at::Tensor &self, + const at::Tensor &src, bool non_blocking) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_floor_divide( - const at::Tensor& self, const at::Tensor& other) { +std::vector +compute_shape_floor_divide(const at::Tensor &self, const at::Tensor &other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_fmod(const at::Tensor& self, - const at::Scalar& other) { +std::vector compute_shape_fmod(const at::Tensor &self, + const at::Scalar &other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } std::vector compute_shape_native_group_norm( - const at::Tensor& input, const c10::optional& weight, - const c10::optional& bias, int64_t N, int64_t C, int64_t HxW, + const at::Tensor &input, const c10::optional &weight, + const c10::optional &bias, int64_t N, int64_t C, int64_t HxW, int64_t group, double eps) { TORCH_CHECK(input.sizes().size() >= 2, @@ -182,9 +245,10 @@ std::vector compute_shape_native_group_norm( return shapes; } -std::vector compute_shape_im2col( - const at::Tensor& self, at::IntArrayRef kernel_size, - at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { +std::vector +compute_shape_im2col(const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef dilation, at::IntArrayRef padding, + at::IntArrayRef stride) { auto self_meta = at::native::empty_strided_meta_symint( self.sym_sizes(), self.sym_strides(), @@ -198,8 +262,8 @@ std::vector compute_shape_im2col( } std::vector compute_shape_native_group_norm_backward( - const at::Tensor& grad_out, const at::Tensor& input, const at::Tensor& mean, - const at::Tensor& rstd, const c10::optional& weight, int64_t N, + const at::Tensor &grad_out, const at::Tensor &input, const at::Tensor &mean, + const at::Tensor &rstd, const c10::optional &weight, int64_t N, int64_t C, int64_t HxW, int64_t group, ::std::array output_mask) { TORCH_CHECK(input.sizes().size() >= 2, @@ -218,26 +282,55 @@ std::vector compute_shape_native_group_norm_backward( return shapes; } -std::vector compute_shape_remainder( - const at::Tensor& self, const at::Scalar& other) { +std::vector +compute_shape_remainder(const at::Tensor &self, const at::Scalar &other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_uniform( - const at::Tensor& self, double from, double to, - c10::optional generator) { +std::vector +compute_shape_reflection_pad2d(const at::Tensor &self, + at::IntArrayRef padding) { + std::vector paddings = padding.vec(); + std::vector in_sizes = self.sizes().vec(); + auto num_dims = in_sizes.size(); + + TORCH_CHECK(padding.size() == 4); + TORCH_CHECK(num_dims >= 2); + + auto vdim = num_dims - 2; + auto hdim = num_dims - 1; + auto padding_left = padding[0]; + auto padding_right = padding[1]; + auto padding_top = padding[2]; + auto padding_bottom = padding[3]; + TORCH_CHECK(padding_left < in_sizes[hdim]); + TORCH_CHECK(padding_right < in_sizes[hdim]); + TORCH_CHECK(padding_top < in_sizes[vdim]); + TORCH_CHECK(padding_bottom < in_sizes[vdim]); + + std::vector out_sizes(in_sizes); + out_sizes[hdim] += padding_left + padding_right; + out_sizes[vdim] += padding_top + padding_bottom; + + return {Shape(self.scalar_type(), out_sizes)}; +} + +std::vector +compute_shape_uniform(const at::Tensor &self, double from, double to, + c10::optional generator) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_normal_functional( - const at::Tensor& self, double mean, double std, - c10::optional generator) { +std::vector +compute_shape_normal_functional(const at::Tensor &self, double mean, double std, + c10::optional generator) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_multinomial( - const at::Tensor& self, int64_t num_samples, bool replacement, - c10::optional generator) { +std::vector +compute_shape_multinomial(const at::Tensor &self, int64_t num_samples, + bool replacement, + c10::optional generator) { // Input tensor can be either 1D or 2D. The last dim of output // should be 'num_samples'. So the output shape can be either // [num_samples] or [m, num_samples]. @@ -247,35 +340,38 @@ std::vector compute_shape_multinomial( return {Shape(at::kLong, ishape)}; } -std::vector compute_shape_eye( - int64_t n, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory) { +std::vector +compute_shape_eye(int64_t n, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { auto out_meta = at::eye(n, dtype, layout, c10::Device(c10::kMeta), pin_memory); return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; } -std::vector compute_shape_eye( - int64_t n, int64_t m, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory) { +std::vector +compute_shape_eye(int64_t n, int64_t m, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { auto out_meta = at::eye(n, m, dtype, layout, c10::Device(c10::kMeta), pin_memory); return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; } -std::vector compute_shape_arange( - const at::Scalar& end, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory) { +std::vector +compute_shape_arange(const at::Scalar &end, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { auto out_meta = at::arange(end, dtype, layout, c10::Device(c10::kMeta), pin_memory); return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; } std::vector compute_shape_arange( - const at::Scalar& start, const at::Scalar& end, + const at::Scalar &start, const at::Scalar &end, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { auto out_meta = at::arange(start, end, dtype, layout, c10::Device(c10::kMeta), @@ -284,7 +380,7 @@ std::vector compute_shape_arange( } std::vector compute_shape_arange( - const at::Scalar& start, const at::Scalar& end, const at::Scalar& step, + const at::Scalar &start, const at::Scalar &end, const at::Scalar &step, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { auto out_meta = at::arange(start, end, step, dtype, layout, @@ -293,34 +389,37 @@ std::vector compute_shape_arange( } std::vector compute_shape_full( - at::IntArrayRef size, const at::Scalar& fill_value, + at::IntArrayRef size, const at::Scalar &fill_value, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { return { Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } -std::vector compute_shape_ones( - at::IntArrayRef size, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory) { +std::vector +compute_shape_ones(at::IntArrayRef size, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { return { Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } -std::vector compute_shape_zeros( - at::IntArrayRef size, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory) { +std::vector +compute_shape_zeros(at::IntArrayRef size, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { return { Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } -std::vector compute_shape_empty( - at::IntArrayRef size, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory, - c10::optional memory_format) { +std::vector +compute_shape_empty(at::IntArrayRef size, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory, + c10::optional memory_format) { return { Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } @@ -333,20 +432,21 @@ std::vector compute_shape_empty_strided( Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } -std::vector compute_shape_fill(const at::Tensor& self, - const at::Scalar& value) { +std::vector compute_shape_fill(const at::Tensor &self, + const at::Scalar &value) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_fill(const at::Tensor& self, - const at::Tensor& value) { +std::vector compute_shape_fill(const at::Tensor &self, + const at::Tensor &value) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_randn( - at::IntArrayRef size, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory) { +std::vector +compute_shape_randn(at::IntArrayRef size, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { return { Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } @@ -367,36 +467,39 @@ std::vector compute_shape_randint( Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } -std::vector compute_shape_resize( - const at::Tensor & self, at::IntArrayRef size, - c10::optional memory_format) { +std::vector +compute_shape_resize(const at::Tensor &self, at::IntArrayRef size, + c10::optional memory_format) { return {Shape(self.scalar_type(), size.vec())}; } -std::vector compute_shape_bernoulli( - const at::Tensor& self, const at::Tensor &p, - c10::optional generator) { +std::vector +compute_shape_bernoulli(const at::Tensor &self, const at::Tensor &p, + c10::optional generator) { return {Shape(self.scalar_type(), self.sizes().vec())}; } std::vector compute_shape_scalar_tensor( - const at::Scalar & s, c10::optional dtype, + const at::Scalar &s, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { return {Shape(dtype.value_or(s.type()), c10::ArrayRef{})}; } -std::vector compute_shape_roll( - const at::Tensor& self, at::IntArrayRef shifts, at::IntArrayRef dims) { +std::vector compute_shape_roll(const at::Tensor &self, + at::IntArrayRef shifts, + at::IntArrayRef dims) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_linspace(const at::Scalar & start, const at::Scalar & end, int64_t steps, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { - auto out_meta = - at::linspace(start, end, steps, dtype, layout, c10::Device(c10::kMeta), pin_memory); +std::vector compute_shape_linspace( + const at::Scalar &start, const at::Scalar &end, int64_t steps, + c10::optional dtype, c10::optional layout, + c10::optional device, c10::optional pin_memory) { + auto out_meta = at::linspace(start, end, steps, dtype, layout, + c10::Device(c10::kMeta), pin_memory); return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; } - -} // namespace lazy -} // namespace torch \ No newline at end of file +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/tensor.cpp b/projects/ltc/csrc/base_lazy_backend/tensor.cpp index 82ae6cc27f4a..5be4ab369ff1 100644 --- a/projects/ltc/csrc/base_lazy_backend/tensor.cpp +++ b/projects/ltc/csrc/base_lazy_backend/tensor.cpp @@ -14,16 +14,16 @@ namespace torch { namespace lazy { -at::Tensor CreateFunctionalizedAtenFromLtcTensor( - const LazyTensorPtr& ltc_tensor) { +at::Tensor +CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr <c_tensor) { at::Tensor tensor = CreateAtenFromLtcTensor(ltc_tensor); if (!c10::impl::tls_is_dispatch_key_excluded( - c10::DispatchKey::Functionalize) && + c10::DispatchKey::Functionalize) && !at::functionalization::impl::isFunctionalTensor(tensor)) { return at::functionalization::impl::to_functional_tensor(tensor); } return tensor; } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/tensor.h b/projects/ltc/csrc/base_lazy_backend/tensor.h index 4e39dd095aa5..18e63ef68cd6 100644 --- a/projects/ltc/csrc/base_lazy_backend/tensor.h +++ b/projects/ltc/csrc/base_lazy_backend/tensor.h @@ -18,7 +18,8 @@ namespace lazy { // should have explicit tensor functinoalization. Otherwise we can get // unfanctionalized primitives or in the worst case if we apply inplace // operations to unfunctionalized tensor it won't be captured in LTC graph. -TORCH_API at::Tensor CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor); +TORCH_API at::Tensor +CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr <c_tensor); } // namespace lazy } // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/utils/exception.h b/projects/ltc/csrc/base_lazy_backend/utils/exception.h index 96510d830aef..533677ad86eb 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/exception.h +++ b/projects/ltc/csrc/base_lazy_backend/utils/exception.h @@ -21,8 +21,8 @@ } #define UNIMPLEMENTED_FUNCTION_ERROR() \ - UNIMPLEMENTED_ERROR( \ - "\n\t" << __FILE__ << ":" << __LINE__ << " " << __PRETTY_FUNCTION__) + UNIMPLEMENTED_ERROR("\n\t" << __FILE__ << ":" << __LINE__ << " " \ + << __PRETTY_FUNCTION__) #define UNSUPPORTED_ERROR(msg) \ { \ diff --git a/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.cpp b/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.cpp index 9ca8b666a42e..a4f3673715e5 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.cpp +++ b/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.cpp @@ -7,9 +7,9 @@ namespace torch { namespace jit { -void ConvertScalarImplicit(std::shared_ptr& graph) { +void ConvertScalarImplicit(std::shared_ptr &graph) { DepthFirstGraphNodeIterator it(graph); - for (auto* node = it.next(); node != nullptr; node = it.next()) { + for (auto *node = it.next(); node != nullptr; node = it.next()) { if (node->kind() != c10::aten::ScalarImplicit) { continue; } @@ -27,15 +27,13 @@ void ConvertScalarImplicit(std::shared_ptr& graph) { node_type = c10::aten::FloatImplicit; output_type = FloatType::get(); } else { - throw std::runtime_error( - "Expected isIntegralType or isFloatingType"); + throw std::runtime_error("Expected isIntegralType or isFloatingType"); } - Value * output = graph - ->create(node_type, {input}) - ->insertBefore(node) - ->output() - ->setType(output_type); + Value *output = graph->create(node_type, {input}) + ->insertBefore(node) + ->output() + ->setType(output_type); node->output()->replaceAllUsesWith(output); node->destroy(); } diff --git a/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.h b/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.h index 2c4214cfc1ab..d9e47b464235 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.h +++ b/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.h @@ -4,7 +4,7 @@ namespace torch { namespace jit { // Convert ScalarImplicit to IntImplicit or FloatImplicit. -TORCH_API void ConvertScalarImplicit(std::shared_ptr& graph); +TORCH_API void ConvertScalarImplicit(std::shared_ptr &graph); } // namespace jit } // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/utils/string_utils.h b/projects/ltc/csrc/base_lazy_backend/utils/string_utils.h index 281331992e49..a5a524b05353 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/string_utils.h +++ b/projects/ltc/csrc/base_lazy_backend/utils/string_utils.h @@ -1,49 +1,49 @@ #pragma once -#include #include +#include #include - template -std::ostream& string_join(std::ostream& out, const std::vector& v, const std::string& delimiter) { - size_t i = 0; - for (const T& e : v) { - if ((i++) > 0) { out << delimiter; } - out << e; +std::ostream &string_join(std::ostream &out, const std::vector &v, + const std::string &delimiter) { + size_t i = 0; + for (const T &e : v) { + if ((i++) > 0) { + out << delimiter; } - return out; + out << e; + } + return out; } template -std::string string_join(const std::vector& v, const std::string& delimiter) { - std::ostringstream joined; - string_join(joined, v, delimiter); - return joined.str(); +std::string string_join(const std::vector &v, const std::string &delimiter) { + std::ostringstream joined; + string_join(joined, v, delimiter); + return joined.str(); } -inline std::vector string_split( - const std::string& str, - const std::string& sep -) { - std::vector tokens; - std::size_t pos1 = str.find_first_not_of(sep); - while (pos1 != std::string::npos) { - std::size_t pos2 = str.find_first_of(sep, pos1); - if (pos2 == std::string::npos) { - tokens.push_back(str.substr(pos1)); - pos1 = pos2; - } else { - tokens.push_back(str.substr(pos1, pos2 - pos1)); - pos1 = str.find_first_not_of(sep, pos2 + 1); - } +inline std::vector string_split(const std::string &str, + const std::string &sep) { + std::vector tokens; + std::size_t pos1 = str.find_first_not_of(sep); + while (pos1 != std::string::npos) { + std::size_t pos2 = str.find_first_of(sep, pos1); + if (pos2 == std::string::npos) { + tokens.push_back(str.substr(pos1)); + pos1 = pos2; + } else { + tokens.push_back(str.substr(pos1, pos2 - pos1)); + pos1 = str.find_first_not_of(sep, pos2 + 1); } - return tokens; + } + return tokens; } /* * Returns true if str starts with prefix */ -inline bool startswith(const std::string& str, const std::string& prefix) { - return str.rfind(prefix, 0) == 0; +inline bool startswith(const std::string &str, const std::string &prefix) { + return str.rfind(prefix, 0) == 0; } diff --git a/projects/ltc/csrc/base_lazy_backend/utils/sys_utils.h b/projects/ltc/csrc/base_lazy_backend/utils/sys_utils.h index 5ae14904909a..5804bce5fd93 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/sys_utils.h +++ b/projects/ltc/csrc/base_lazy_backend/utils/sys_utils.h @@ -6,24 +6,25 @@ namespace sys_util { template -static T GetEnv(const std::string& name, const T& default_value = T(0)) { - const char* env = std::getenv(name.c_str()); +static T GetEnv(const std::string &name, const T &default_value = T(0)) { + const char *env = std::getenv(name.c_str()); if (!env) { return default_value; } return T(std::atoi(env)); } -static std::string GetEnvString(const std::string& name, const std::string& default_value) { - const char* env = std::getenv(name.c_str()); +[[maybe_unused]] static std::string +GetEnvString(const std::string &name, const std::string &default_value) { + const char *env = std::getenv(name.c_str()); if (!env) { return default_value; } return std::string(env); } -static bool GetEnvBool(const char* name, bool defval) { - const char* env = std::getenv(name); +[[maybe_unused]] static bool GetEnvBool(const char *name, bool defval) { + const char *env = std::getenv(name); if (env == nullptr) { return defval; } diff --git a/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.cpp b/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.cpp index cdd97168031b..71a0e89f4c64 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.cpp +++ b/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.cpp @@ -3,84 +3,90 @@ #include "../generated/LazyIr.h" #include "../mlir_node.h" - namespace torch { namespace lazy { -bool is_detach_copy(const torch::lazy::Node* node) { - return node && node->op() == torch::lazy::DetachCopy::ClassOpKind(); +bool is_detach_copy(const torch::lazy::Node *node) { + return node && node->op() == torch::lazy::DetachCopy::ClassOpKind(); } -bool is_detach_copy(const torch::lazy::Value& value) { - return is_detach_copy(value.node.get()); +bool is_detach_copy(const torch::lazy::Value &value) { + return is_detach_copy(value.node.get()); } -torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node* node) { - if (!node) { return nullptr; } +torch::lazy::Node *extract_non_detach_copy_node(torch::lazy::Node *node) { + if (!node) { + return nullptr; + } - torch::lazy::TorchMlirNode* mlir_node = dynamic_cast(node); - while(mlir_node && is_detach_copy(mlir_node)) { - mlir_node = mlir_node->mlir_node(0); - } - if (!mlir_node) { - return node; - } - return mlir_node; + torch::lazy::TorchMlirNode *mlir_node = + dynamic_cast(node); + while (mlir_node && is_detach_copy(mlir_node)) { + mlir_node = mlir_node->mlir_node(0); + } + if (!mlir_node) { + return node; + } + return mlir_node; } -const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node* node) { - if (!node) { return nullptr; } +const torch::lazy::Node * +extract_non_detach_copy_node(const torch::lazy::Node *node) { + if (!node) { + return nullptr; + } - const torch::lazy::TorchMlirNode* mlir_node = dynamic_cast(node); - while(mlir_node && is_detach_copy(mlir_node)) { - mlir_node = mlir_node->mlir_node(0); - } - if (!mlir_node) { - return node; - } - return mlir_node; + const torch::lazy::TorchMlirNode *mlir_node = + dynamic_cast(node); + while (mlir_node && is_detach_copy(mlir_node)) { + mlir_node = mlir_node->mlir_node(0); + } + if (!mlir_node) { + return node; + } + return mlir_node; } - -torch::lazy::DeviceData* device_data_cast(torch::lazy::Node* node) { - if (!node) { - return nullptr; - } - node = extract_non_detach_copy_node(node); - if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) { - return dynamic_cast(node); - } +torch::lazy::DeviceData *device_data_cast(torch::lazy::Node *node) { + if (!node) { return nullptr; + } + node = extract_non_detach_copy_node(node); + if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) { + return dynamic_cast(node); + } + return nullptr; } -const torch::lazy::DeviceData* device_data_cast(const torch::lazy::Node* node) { - if (!node) { - return nullptr; - } - node = extract_non_detach_copy_node(node); - if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) { - return dynamic_cast(node); - } +const torch::lazy::DeviceData *device_data_cast(const torch::lazy::Node *node) { + if (!node) { return nullptr; + } + node = extract_non_detach_copy_node(node); + if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) { + return dynamic_cast(node); + } + return nullptr; } -torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value) { - if (!value) { - return nullptr; - } - return device_data_cast(value.node.get()); +torch::lazy::DeviceData *device_data_cast(const torch::lazy::Value &value) { + if (!value) { + return nullptr; + } + return device_data_cast(value.node.get()); } -torch::lazy::DeviceData* device_data_cast( - const at::Tensor& tensor, c10::optional device -) { - if (!device) { - device = torch::lazy::GetBackendDevice(tensor); - } - TORCH_CHECK(device); - torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(tensor, *device); - if (lazy_tensor) { - return device_data_cast(lazy_tensor->GetIrValue()); - } - return nullptr; +torch::lazy::DeviceData * +device_data_cast(const at::Tensor &tensor, + c10::optional device) { + if (!device) { + device = torch::lazy::GetBackendDevice(tensor); + } + TORCH_CHECK(device); + torch::lazy::LazyTensorPtr lazy_tensor = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(tensor, *device); + if (lazy_tensor) { + return device_data_cast(lazy_tensor->GetIrValue()); + } + return nullptr; } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.h b/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.h index 745be78c35d2..f8e5e317294a 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.h +++ b/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.h @@ -8,18 +8,21 @@ namespace torch { namespace lazy { -TORCH_API bool is_detach_copy(const torch::lazy::Node*); -TORCH_API bool is_detach_copy(const torch::lazy::Value&); +TORCH_API bool is_detach_copy(const torch::lazy::Node *); +TORCH_API bool is_detach_copy(const torch::lazy::Value &); -TORCH_API torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node*); -TORCH_API const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node*); +TORCH_API torch::lazy::Node *extract_non_detach_copy_node(torch::lazy::Node *); +TORCH_API const torch::lazy::Node * +extract_non_detach_copy_node(const torch::lazy::Node *); -TORCH_API torch::lazy::DeviceData* device_data_cast(torch::lazy::Node*); -TORCH_API const torch::lazy::DeviceData* device_data_cast(const torch::lazy::Node*); -TORCH_API torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value); -TORCH_API torch::lazy::DeviceData* device_data_cast( - const at::Tensor& tensor, c10::optional device = c10::nullopt -); +TORCH_API torch::lazy::DeviceData *device_data_cast(torch::lazy::Node *); +TORCH_API const torch::lazy::DeviceData * +device_data_cast(const torch::lazy::Node *); +TORCH_API torch::lazy::DeviceData * +device_data_cast(const torch::lazy::Value &value); +TORCH_API torch::lazy::DeviceData *device_data_cast( + const at::Tensor &tensor, + c10::optional device = c10::nullopt); -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/onnx_c_importer/CMakeLists.txt b/projects/onnx_c_importer/CMakeLists.txt new file mode 100644 index 000000000000..681ca14feafc --- /dev/null +++ b/projects/onnx_c_importer/CMakeLists.txt @@ -0,0 +1,34 @@ +message(STATUS "Enabling onnx_c_importer...") + +include(FetchContent) + +find_package(Protobuf REQUIRED CONFIG) + +option(ONNX_DISABLE_EXCEPTIONS "For compatibility with LLVM build" ON) + +FetchContent_Declare( + onnx + EXCLUDE_FROM_ALL + GIT_REPOSITORY https://github.com/onnx/onnx.git + GIT_TAG v1.15.0 + GIT_SHALLOW ON + GIT_PROGRESS ON +) +FetchContent_MakeAvailable(onnx) + +add_llvm_executable( + torch-mlir-import-onnx + PARTIAL_SOURCES_INTENDED + + import-onnx-main.cpp + OnnxImporter.h + OnnxImporter.cpp +) + +target_link_libraries( + torch-mlir-import-onnx + LLVMSupport + MLIRCAPIIR + TorchMLIRCAPI + onnx +) diff --git a/projects/onnx_c_importer/OnnxImporter.cpp b/projects/onnx_c_importer/OnnxImporter.cpp new file mode 100644 index 000000000000..4a61a2800ca5 --- /dev/null +++ b/projects/onnx_c_importer/OnnxImporter.cpp @@ -0,0 +1,1009 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "OnnxImporter.h" + +#include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/BuiltinTypes.h" + +#include +#include + +using namespace torch_mlir_onnx; + +namespace { + +std::string SanitizeNameAsIdentifier(std::string_view in) { + std::string out; + if (!in.empty() && !std::isalnum(in.front())) { + out.append("_"); + } + out.append(in); + for (char &c : out) { + if (c == ':' || c == '/') + c = '_'; + } + return out; +} + +template +void AppendDelimittedStrings(std::string &into, T &container) { + bool first = true; + for (auto &item : container) { + if (first) { + first = false; + } else { + into.append(", "); + } + into.append(item); + } +} + +inline MlirStringRef toMlirStringRef(const std::string_view &s) { + return mlirStringRefCreate(s.data(), s.size()); +} + +inline MlirStringRef toMlirStringRef(const std::string &s) { + return mlirStringRefCreate(s.data(), s.size()); +} + +inline MlirStringRef toMlirStringRef(const char *s) { + return mlirStringRefCreate(s, std::strlen(s)); +} + +inline MlirNamedAttribute toMlirNamedAttribute(const char *s, + MlirAttribute attr) { + MlirContext context = mlirAttributeGetContext(attr); + MlirIdentifier ident = mlirIdentifierGet(context, toMlirStringRef(s)); + return mlirNamedAttributeGet(ident, attr); +} + +std::string getMlirAsm(MlirType t) { + std::string result; + mlirTypePrint( + t, + +[](MlirStringRef sr, void *userData) { + std::string *s = static_cast(userData); + s->append(sr.data, sr.length); + }, + static_cast(&result)); + return result; +} + +// C++ helpers to create operations. +void addToMlirOperationState(MlirOperationState &state, + MlirNamedAttribute namedAttr) { + mlirOperationStateAddAttributes(&state, 1, &namedAttr); +} + +void addToMlirOperationState( + MlirOperationState &state, + std::vector> &attrs) { + for (auto &p : attrs) { + addToMlirOperationState(state, + toMlirNamedAttribute(p.first.c_str(), p.second)); + } +} + +void addToMlirOperationState(MlirOperationState &state, MlirRegion region) { + mlirOperationStateAddOwnedRegions(&state, 1, ®ion); +} + +[[maybe_unused]] void addToMlirOperationState(MlirOperationState &state, + MlirValue value) { + mlirOperationStateAddOperands(&state, 1, &value); +} + +void addToMlirOperationState(MlirOperationState &state, + const std::vector &values) { + mlirOperationStateAddOperands(&state, values.size(), values.data()); +} + +void addToMlirOperationState(MlirOperationState &state, MlirType resultType) { + mlirOperationStateAddResults(&state, 1, &resultType); +} + +void addToMlirOperationState(MlirOperationState &state, + const std::vector &resultTypes) { + mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data()); +} + +[[maybe_unused]] void addToMlirOperationState(MlirOperationState &state) {} + +template +void addToMlirOperationState(MlirOperationState &state, T &&t, U &&u, + Ts &&...ts) { + addToMlirOperationState(state, std::forward(t)); + addToMlirOperationState(state, std::forward(u), std::forward(ts)...); +} + +template +MlirOperation createMlirOperation(std::string name, MlirLocation loc, + Ts &&...ts) { + MlirOperationState state = mlirOperationStateGet(toMlirStringRef(name), loc); + addToMlirOperationState(state, std::forward(ts)...); + return mlirOperationCreate(&state); +} + +template +MlirOperation createMlirOperationAtEnd(MlirBlock block, std::string name, + MlirLocation loc, Ts &&...ts) { + MlirOperation operation = + createMlirOperation(name, loc, std::forward(ts)...); + mlirBlockInsertOwnedOperationBefore(block, mlirBlockGetTerminator(block), + operation); + return operation; +} + +} // namespace + +// ---------------------------------------------------------------------------// +// ModelInfo +// ---------------------------------------------------------------------------// + +ModelInfo::ModelInfo() = default; + +void ModelInfo::DebugDumpProto() { + std::string debug_string = model_proto_.DebugString(); + fprintf(stderr, "%s\n", debug_string.c_str()); +} + +Status ModelInfo::Initialize() { + if (!model_proto_.has_graph()) { + return SetError("ONNX ModelProto has no main graph"); + } + main_graph_ = std::make_unique(*this, model_proto_.graph()); + if (failed(main_graph_->Initialize())) { + return failure(); + } + + return success(); +} + +// ---------------------------------------------------------------------------// +// GraphInfo +// ---------------------------------------------------------------------------// + +Status GraphInfo::Initialize() { + // Initialize look up tables. + for (const onnx::TensorProto &t : graph_proto_.initializer()) { + initializer_map_.emplace(t.name(), t); + } + for (const onnx::ValueInfoProto &v : graph_proto_.value_info()) { + value_info_map_.emplace(v.name(), v); + } + for (const onnx::ValueInfoProto &v : graph_proto_.input()) { + declared_inputs_.emplace_back(&v); + } + for (const onnx::ValueInfoProto &v : graph_proto_.output()) { + outputs_.emplace_back(&v); + } + + // Generate the effective input map, which for old models can be a subset of + // the input map. + if (model_info_.config().elide_initialized_inputs) { + // Default. Add declared inputs to the input map unless if they appear + // as an initializer. + for (const onnx::ValueInfoProto *it : declared_inputs_) { + std::string_view key = it->name(); + if (initializer_map_.find(key) != initializer_map_.end()) { + // In initializers. Skip. + continue; + } + inputs_.emplace_back(it); + } + } else { + // Fallback for some legacy compatibility. + inputs_ = declared_inputs_; + std::vector illegal_keys; + for (const onnx::ValueInfoProto *it : inputs_) { + std::string_view key = it->name(); + if (initializer_map_.find(key) != initializer_map_.end()) { + illegal_keys.push_back(key); + } + } + if (!illegal_keys.empty()) { + std::string error = "When not in elide_initialized_inputs=true mode, we " + "expect inputs to not have an initial value (got "; + AppendDelimittedStrings(error, illegal_keys); + error.append(")"); + return model_info_.SetError(std::move(error)); + } + } + + // Index the inputs and outputs. + for (auto *input : inputs_) { + input_map_.emplace(input->name(), *input); + } + for (auto *output : outputs_) { + output_map_.emplace(output->name(), *output); + } + return success(); +} + +const onnx::TypeProto *GraphInfo::FindTypeProtoForName(std::string_view name) { + // Node outputs don't typically have type information, but shape inference + // will associate them in the value_info. If not there, it may be a + // graph output, which must have type information. + { + auto it = value_info_map_.find(name); + if (it != value_info_map_.end()) { + return &it->second.type(); + } + } + { + auto it = output_map_.find(name); + if (it != output_map_.end()) { + return &it->second.type(); + } + } + + std::string msg = "No type information associated with '"; + msg.append(name); + msg.append("'. Run shape inference?"); + model_info_.SetError(std::move(msg)); + return nullptr; +} + +// ---------------------------------------------------------------------------// +// ContextCache +// ---------------------------------------------------------------------------// + +MlirType ContextCache::ConvertTypeProto(const onnx::TypeProto &tp) { + if (tp.has_tensor_type()) { + // Convert Tensor TypeProto. + const onnx::TypeProto_Tensor &tt = tp.tensor_type(); + if (!tt.has_shape()) { + std::string msg = + "Unsupported Tensor type without shape (run shape inference?): "; + msg.append(tt.DebugString()); + model_info_.SetError(std::move(msg)); + return {nullptr}; + } + + MlirType element_type = ConvertTensorElementType(tt.elem_type()); + if (mlirTypeIsNull(element_type)) { + return {nullptr}; + } + shared_dims_.clear(); + shared_dims_.reserve(6); + for (const onnx::TensorShapeProto::Dimension &dim : tt.shape().dim()) { + if (dim.has_dim_value()) { + // Static. + shared_dims_.push_back(dim.dim_value()); + } else { + // Dynamic. + shared_dims_.push_back(-1); + } + } + + return GetVtensorType(shared_dims_, element_type); + } else { + std::string msg = "Unsupported ONNX TypeProto: "; + msg.append(tp.DebugString()); + model_info_.SetError(std::move(msg)); + return {nullptr}; + } +} + +MlirType ContextCache::ConvertTensorElementType(int elem_type) { + auto it = elem_type_map_.find(elem_type); + if (it != elem_type_map_.end()) { + return it->second; + } + + MlirType t = {nullptr}; + switch (elem_type) { + case onnx::TensorProto::FLOAT: + t = mlirF32TypeGet(context_); + break; + case onnx::TensorProto::UINT8: + t = mlirIntegerTypeUnsignedGet(context_, 8); + break; + case onnx::TensorProto::INT8: + t = mlirIntegerTypeSignedGet(context_, 8); + break; + case onnx::TensorProto::UINT16: + t = mlirIntegerTypeUnsignedGet(context_, 16); + break; + case onnx::TensorProto::INT16: + t = mlirIntegerTypeSignedGet(context_, 16); + break; + case onnx::TensorProto::INT32: + t = mlirIntegerTypeSignedGet(context_, 32); + break; + case onnx::TensorProto::UINT32: + t = mlirIntegerTypeUnsignedGet(context_, 32); + break; + case onnx::TensorProto::INT64: + t = mlirIntegerTypeSignedGet(context_, 64); + break; + case onnx::TensorProto::UINT64: + t = mlirIntegerTypeUnsignedGet(context_, 64); + break; + case onnx::TensorProto::BOOL: + t = mlirIntegerTypeGet(context_, 1); + break; + case onnx::TensorProto::FLOAT16: + t = mlirF16TypeGet(context_); + break; + case onnx::TensorProto::DOUBLE: + t = mlirF64TypeGet(context_); + break; + case onnx::TensorProto::COMPLEX64: + t = mlirComplexTypeGet(mlirF32TypeGet(context_)); + break; + case onnx::TensorProto::COMPLEX128: + t = mlirComplexTypeGet(mlirF64TypeGet(context_)); + break; + case onnx::TensorProto::BFLOAT16: + t = mlirBF16TypeGet(context_); + break; + case onnx::TensorProto::FLOAT8E4M3FN: + t = mlirFloat8E4M3FNTypeGet(context_); + break; + case onnx::TensorProto::FLOAT8E4M3FNUZ: + t = mlirFloat8E4M3FNUZTypeGet(context_); + break; + case onnx::TensorProto::FLOAT8E5M2: + t = mlirFloat8E5M2TypeGet(context_); + break; + case onnx::TensorProto::FLOAT8E5M2FNUZ: + t = mlirFloat8E5M2FNUZTypeGet(context_); + break; + default: { + std::string msg = "Unknown ONNX tensor element type: "; + msg.append(std::to_string(elem_type)); + model_info_.SetError(std::move(msg)); + return {nullptr}; + } + } + + assert(t.ptr && "did not convert type"); + elem_type_map_[elem_type] = t; + return t; +} + +MlirAttribute +ContextCache::ConvertTensorProtoToAttr(const onnx::TensorProto &tp) { + MlirType tensor_type = ConvertTensorProtoToBuiltinType(tp); + if (tp.has_raw_data()) { + std::string sanitized_name = SanitizeNameAsIdentifier(tp.name()); + // Conveniently, DenseResourceElementsAttr shares the raw data + // format. We just give it maximum numeric alignment. + return mlirUnmanagedDenseResourceElementsAttrGet( + tensor_type, toMlirStringRef(sanitized_name), + const_cast(static_cast(tp.raw_data().data())), + tp.raw_data().size(), /*dataAlignment=*/8, /*dataIsMutable=*/false, + /*deleter=*/nullptr, /*userData=*/nullptr); + } else { + switch (tp.data_type()) { + case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT: + return mlirDenseElementsAttrFloatGet(tensor_type, tp.float_data_size(), + tp.float_data().data()); + case onnx::TensorProto::DataType::TensorProto_DataType_INT32: + return mlirDenseElementsAttrInt32Get(tensor_type, tp.int32_data_size(), + tp.int32_data().data()); + case onnx::TensorProto::DataType::TensorProto_DataType_INT64: + return mlirDenseElementsAttrInt64Get(tensor_type, tp.int64_data_size(), + tp.int64_data().data()); + case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE: + return mlirDenseElementsAttrDoubleGet(tensor_type, tp.double_data_size(), + tp.double_data().data()); + case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: { + // Special case. See proto. Someone apparently got lazy. + std::vector stupid_conversion; + stupid_conversion.reserve(tp.uint64_data_size()); + for (uint64_t v : tp.uint64_data()) + stupid_conversion.push_back(v); + return mlirDenseElementsAttrUInt32Get( + tensor_type, stupid_conversion.size(), stupid_conversion.data()); + } + case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: + return mlirDenseElementsAttrUInt64Get(tensor_type, tp.uint64_data_size(), + tp.uint64_data().data()); + } + } + + std::string message = + "Unable to convert ONNX TensorProto to MLIR attribute: "; + message.append(tp.DebugString()); + model_info_.SetError(std::move(message)); + return {nullptr}; +} + +MlirType +ContextCache::ConvertTensorProtoToBuiltinType(const onnx::TensorProto &tp) { + MlirType element_type = ConvertTensorElementType(tp.data_type()); + if (mlirTypeIsNull(element_type)) + return {nullptr}; + + shared_dims_.clear(); + for (auto dim : tp.dims()) { + shared_dims_.push_back(dim); + } + return mlirRankedTensorTypeGet(shared_dims_.size(), shared_dims_.data(), + element_type, + /*encoding=*/{nullptr}); +} + +MlirType +ContextCache::ConvertTensorProtoToVtensorType(const onnx::TensorProto &tp) { + MlirType element_type = ConvertTensorElementType(tp.data_type()); + if (mlirTypeIsNull(element_type)) + return {nullptr}; + + shared_dims_.clear(); + for (auto dim : tp.dims()) { + shared_dims_.push_back(dim); + } + + return GetVtensorType(shared_dims_, element_type); +} + +MlirType ContextCache::GetVtensorType(const std::vector &dims, + MlirType element_type) { + std::string type_asm = "!torch.vtensor<["; + // Add dimension list. + bool first_dim = true; + for (int dim : dims) { + if (first_dim) + first_dim = false; + else + type_asm.push_back(','); + if (dim < 0) + type_asm.push_back('?'); + else + type_asm.append(std::to_string(dim)); + } + type_asm.append("],"); + + // Add element type. + type_asm.append(getMlirAsm(element_type)); + type_asm.push_back('>'); + + // Look in cache. + auto found_it = asm_type_map_.find(type_asm); + if (found_it != asm_type_map_.end()) { + return found_it->second; + } + + // Parse. + MlirType t = mlirTypeParseGet(context_, toMlirStringRef(type_asm)); + if (mlirTypeIsNull(t)) { + std::string message = + "internal error: could not parse !torch.vtensor type: "; + message.append(type_asm); + model_info_.SetError(std::move(message)); + return t; + } + + asm_type_map_[std::move(type_asm)] = t; + return t; +} + +// ---------------------------------------------------------------------------// +// NodeImporter +// ---------------------------------------------------------------------------// + +NodeImporter::NodeImporter(GraphInfo &graph_info, ContextCache &cc, + MlirOperation module_op) + : graph_info_(graph_info), cc_(cc), + context_(mlirOperationGetContext(module_op)), module_op_(module_op), + func_op_({nullptr}), body_block_({nullptr}) { + std::string locName = "graph:"; + locName.append(graph_info.graph_proto().name()); + default_loc_ = mlirLocationNameGet(context_, toMlirStringRef(locName), + /*childLoc=*/{nullptr}); +} + +Status NodeImporter::DefineFunction(std::optional name) { + const onnx::GraphProto &p = graph_info_.graph_proto(); + MlirRegion moduleBodyRegion = mlirOperationGetRegion(module_op_, 0); + MlirBlock moduleBody = mlirRegionGetFirstBlock(moduleBodyRegion); + MlirAttribute nameAttr; + if (name) { + // Explicitly named. + nameAttr = mlirStringAttrGet(context_, toMlirStringRef(*name)); + } else { + // Name it according to the graph. + nameAttr = mlirStringAttrGet(context_, toMlirStringRef(p.name())); + } + + // Derive the FunctionType. + std::vector input_types; + std::vector input_locs; + std::vector output_types; + for (auto *input : graph_info_.inputs()) { + MlirType t = cc_.ConvertTypeProto(input->type()); + if (mlirTypeIsNull(t)) { + return failure(); + } + input_types.push_back(t); + input_locs.push_back(default_loc_); + } + for (auto *output : graph_info_.outputs()) { + MlirType t = cc_.ConvertTypeProto(output->type()); + if (mlirTypeIsNull(t)) { + return failure(); + } + output_types.push_back(t); + } + MlirType ftype = + mlirFunctionTypeGet(context_, input_types.size(), input_types.data(), + output_types.size(), output_types.data()); + + // Create func.func. + func_op_ = createMlirOperationAtEnd( + moduleBody, "func.func", default_loc_, mlirRegionCreate(), + toMlirNamedAttribute("function_type", mlirTypeAttrGet(ftype)), + toMlirNamedAttribute("sym_name", nameAttr)); + + // Add entry block. + body_block_ = mlirBlockCreate(input_types.size(), input_types.data(), + input_locs.data()); + MlirRegion bodyRegion = mlirOperationGetRegion(func_op_, 0); + mlirRegionAppendOwnedBlock(bodyRegion, body_block_); + + // Map the block args to names and store for evaluation. + for (int i = 0, e = graph_info_.inputs().size(); i < e; ++i) { + std::string_view name = graph_info_.inputs()[i]->name(); + MlirValue value = mlirBlockGetArgument(body_block_, i); + nv_map_[name] = value; + } + + PopulateGraphAttrs(func_op_); + return success(); +} + +void NodeImporter::PopulateGraphAttrs(MlirOperation container_op) { + const onnx::ModelProto &m = graph_info_.model_info().model_proto(); + MlirType i64_type = mlirIntegerTypeSignedGet(context_, 64); + int default_opset_version = 0; + std::unordered_map opset_versions; + // Determine model level opset versions. + for (const onnx::OperatorSetIdProto &opset_import : m.opset_import()) { + if (opset_import.has_domain()) { + opset_versions[opset_import.domain()] = + mlirIntegerAttrGet(i64_type, opset_import.version()); + } else { + default_opset_version = opset_import.version(); + } + } + + // Set the default domain version. + if (default_opset_version != 0) { + mlirOperationSetDiscardableAttributeByName( + container_op, toMlirStringRef("torch.onnx_meta.opset_version"), + mlirIntegerAttrGet(i64_type, default_opset_version)); + } + + // Set versions for other domains. + if (!opset_versions.empty()) { + std::vector version_attrs; + for (auto it : opset_versions) { + version_attrs.push_back(mlirNamedAttributeGet( + mlirIdentifierGet(context_, toMlirStringRef(it.first)), it.second)); + } + MlirAttribute dict_attr = mlirDictionaryAttrGet( + context_, version_attrs.size(), version_attrs.data()); + mlirOperationSetDiscardableAttributeByName( + container_op, toMlirStringRef("torch.onnx_meta.opset_versions"), + dict_attr); + } + + // IR version and producer. + mlirOperationSetDiscardableAttributeByName( + container_op, toMlirStringRef("torch.onnx_meta.ir_version"), + mlirIntegerAttrGet(i64_type, m.ir_version())); + mlirOperationSetDiscardableAttributeByName( + container_op, toMlirStringRef("torch.onnx_meta.producer_name"), + mlirStringAttrGet(context_, toMlirStringRef(m.producer_name()))); + mlirOperationSetDiscardableAttributeByName( + container_op, toMlirStringRef("torch.onnx_meta.producer_version"), + mlirStringAttrGet(context_, toMlirStringRef(m.producer_version()))); +} + +Status NodeImporter::ImportAll() { + // TODO: Consider pulling in initializers on demand since there can be so + // much unused crap. + for (auto it : graph_info_.initializer_map()) { + if (failed(ImportInitializer(it.second))) + return failure(); + } + for (auto it : graph_info_.graph_proto().node()) { + if (failed(ImportNode(it))) + return failure(); + } + + // Lookup the outputs, which should all be in the nv_map if the graph was + // properly formed. + std::vector output_values; + for (const auto *output : graph_info_.outputs()) { + std::string_view name = output->name(); + auto found_it = nv_map_.find(name); + if (found_it == nv_map_.end()) { + std::string msg = "Non topologically produced ONNX graph output '"; + msg.append(name); + msg.append("'"); + return SetError(std::move(msg)); + } + output_values.push_back(found_it->second); + } + + createMlirOperationAtEnd(body_block_, "func.return", default_loc_, + output_values); + return success(); +} + +Status NodeImporter::ImportInitializer(const onnx::TensorProto &initializer) { + std::string_view name = initializer.name(); + MlirLocation loc = mlirLocationNameGet(context_, toMlirStringRef(name), + /*childLoc=*/{nullptr}); + + MlirAttribute value_attr = cc_.ConvertTensorProtoToAttr(initializer); + MlirType vtensor_type = cc_.ConvertTensorProtoToVtensorType(initializer); + if (mlirAttributeIsNull(value_attr) || mlirTypeIsNull(vtensor_type)) + return failure(); + + MlirOperation op = createMlirOperationAtEnd( + body_block_, "torch.vtensor.literal", loc, vtensor_type, + toMlirNamedAttribute("value", value_attr)); + MlirValue result = mlirOperationGetResult(op, 0); + + auto inserted = nv_map_.insert(std::make_pair(name, result)); + if (!inserted.second) { + std::string msg = "Multiple nodes produced a value for '"; + msg.append(name); + msg.append("', most recent from "); + msg.append(initializer.DebugString()); + return SetError(std::move(msg)); + } + + return success(); +} + +Status NodeImporter::ImportNode(const onnx::NodeProto &node) { + std::string_view op_type = node.op_type(); + // Handle special-form op types that do not go down the generic path. + if (op_type == "ConstantOfShape") { + return ImportConstantOfShapeNode(node); + } + + return ImportGeneralNode(node); +} + +Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) { + MlirLocation loc = mlirLocationNameGet(context_, toMlirStringRef(node.name()), + /*childLoc=*/{nullptr}); + + // Map inputs to values. + std::vector input_values; + for (auto &input_name : node.input()) { + auto found_it = nv_map_.find(input_name); + if (found_it == nv_map_.end()) { + std::string msg = "Non topologically produced ONNX node input '"; + msg.append(input_name); + msg.append("'"); + return SetError(std::move(msg)); + } + input_values.push_back(found_it->second); + } + + // Map outputs to types. + std::vector output_types; + for (auto &output_name : node.output()) { + const onnx::TypeProto *type_proto = + graph_info_.FindTypeProtoForName(output_name); + if (!type_proto) + return failure(); + + MlirType t = cc_.ConvertTypeProto(*type_proto); + if (mlirTypeIsNull(t)) + return failure(); + output_types.push_back(t); + } + + // Derive the op name. + std::string op_name = "onnx."; + op_name.append(node.op_type()); + MlirAttribute op_name_attr = + mlirStringAttrGet(context_, toMlirStringRef(op_name)); + + // General attributes. + std::vector> general_attributes; + for (auto &onnx_attr : node.attribute()) { + MlirAttribute attr = ImportGeneralAttribute(onnx_attr); + if (mlirAttributeIsNull(attr)) + return failure(); + std::string full_name = "torch.onnx."; + full_name.append(onnx_attr.name()); + general_attributes.push_back(std::make_pair(full_name, attr)); + } + + // Create op. + MlirOperation op = createMlirOperationAtEnd( + body_block_, "torch.operator", loc, output_types, input_values, + toMlirNamedAttribute("name", op_name_attr), general_attributes); + + // Record the result values. + for (int i = 0, e = output_types.size(); i < e; ++i) { + MlirValue result = mlirOperationGetResult(op, i); + std::string_view name = node.output(i); + auto inserted = nv_map_.insert(std::make_pair(name, result)); + if (!inserted.second) { + std::string msg = "Multiple nodes produced a value for '"; + msg.append(name); + msg.append("', most recent from "); + msg.append(node.DebugString()); + return SetError(std::move(msg)); + } + } + + return success(); +} + +MlirAttribute +NodeImporter::ImportGeneralAttribute(const onnx::AttributeProto &onnx_attr) { + switch (onnx_attr.type()) { + case onnx::AttributeProto::UNDEFINED: + SetError("'UNDEFINED' attribute type not supported"); + return {nullptr}; + case onnx::AttributeProto::FLOAT: + return mlirFloatAttrDoubleGet(context_, mlirF32TypeGet(context_), + onnx_attr.f()); + case onnx::AttributeProto::INT: + return mlirIntegerAttrGet(mlirIntegerTypeSignedGet(context_, 64), + onnx_attr.i()); + case onnx::AttributeProto::STRING: + return mlirStringAttrGet(context_, toMlirStringRef(onnx_attr.s())); + case onnx::AttributeProto::TENSOR: + return cc_.ConvertTensorProtoToAttr(onnx_attr.t()); + case onnx::AttributeProto::GRAPH: + SetError("'GRAPH' attribute type not supported on this node"); + return {nullptr}; + case onnx::AttributeProto::SPARSE_TENSOR: + SetError("'SPARSE_TENSOR' attribute type not supported on this node"); + return {nullptr}; + case onnx::AttributeProto::TYPE_PROTO: + SetError("'TYPE_PROTO' attribute type not supported on this node"); + return {nullptr}; + case onnx::AttributeProto::FLOATS: { + std::vector attrs; + for (auto f : onnx_attr.floats()) + attrs.push_back( + mlirFloatAttrDoubleGet(context_, mlirF32TypeGet(context_), f)); + return mlirArrayAttrGet(context_, attrs.size(), attrs.data()); + } + case onnx::AttributeProto::INTS: { + std::vector attrs; + for (auto i : onnx_attr.ints()) + attrs.push_back( + mlirIntegerAttrGet(mlirIntegerTypeSignedGet(context_, 64), i)); + return mlirArrayAttrGet(context_, attrs.size(), attrs.data()); + } + case onnx::AttributeProto::STRINGS: { + std::vector attrs; + for (auto s : onnx_attr.strings()) + attrs.push_back(mlirStringAttrGet(context_, toMlirStringRef(s))); + return mlirArrayAttrGet(context_, attrs.size(), attrs.data()); + } + case onnx::AttributeProto::TENSORS: { + std::vector attrs; + for (auto &t : onnx_attr.tensors()) { + MlirAttribute attr = cc_.ConvertTensorProtoToAttr(t); + if (mlirAttributeIsNull(attr)) + return {nullptr}; + attrs.push_back(attr); + } + return mlirArrayAttrGet(context_, attrs.size(), attrs.data()); + } + case onnx::AttributeProto::GRAPHS: + SetError("'GRAPHS' attribute type not supported on this node"); + return {nullptr}; + case onnx::AttributeProto::SPARSE_TENSORS: + SetError("'SPARSE_TENSORS' attribute type not supported on this node"); + return {nullptr}; + case onnx::AttributeProto::TYPE_PROTOS: + SetError("'TYPE_PROTOS' attribute type not supported on this node"); + return {nullptr}; + } + + std::string msg = "Unhandled ONNX attribute type code "; + msg.append(std::to_string(onnx_attr.type())); + msg.append(": "); + msg.append(onnx_attr.DebugString()); + SetError(std::move(msg)); + return {nullptr}; +} + +Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) { + std::string_view name = node.name(); + MlirLocation loc = mlirLocationNameGet(context_, toMlirStringRef(name), + /*childLoc=*/{nullptr}); + + // This op is special: It has an input of the shape, and in full generality + // could involve eager production of constants of variable size. In + // practice, the DNN profile for ONNX makes this very difficult to do + // and we hard-assert that the input can be resolved to an immediate + // value. + if (node.input_size() != 1 || node.output_size() != 1) { + return SetError("ConstantOfShape node must have one input and output"); + } + + // Shape. + std::vector shape; + if (failed(GetImmediateShapeTensor(node.input(0), shape))) + return failure(); + + // Value. + const onnx::AttributeProto *value_proto = nullptr; + for (auto &attr : node.attribute()) { + if (attr.name() == "value") { + value_proto = &attr; + break; + } + } + if (!value_proto) { + return SetError("ConstantOfShape node must have a 'value' attribute"); + } + if (value_proto->type() != onnx::AttributeProto_AttributeType_TENSOR) { + return SetError("ConstantOfShape node must have a tensor value attribute"); + } + + // Create the splat. + const onnx::TensorProto &tensor_proto = value_proto->t(); + if (tensor_proto.dims_size() != 1 || tensor_proto.dims(0) != 1) { + return SetError("ConstantOfShape node expected a scalar tensor value"); + } + auto tensorTypeFor = [&](MlirType element_type) { + return mlirRankedTensorTypeGet(shape.size(), shape.data(), element_type, + /*encoding*/ {nullptr}); + }; + MlirAttribute splat_attr = {nullptr}; + switch (tensor_proto.data_type()) { + case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT: + splat_attr = mlirDenseElementsAttrFloatSplatGet( + tensorTypeFor(mlirF32TypeGet(context_)), tensor_proto.float_data(0)); + break; + case onnx::TensorProto::DataType::TensorProto_DataType_INT32: + splat_attr = mlirDenseElementsAttrFloatSplatGet( + tensorTypeFor(mlirIntegerTypeSignedGet(context_, 32)), + tensor_proto.int32_data(0)); + break; + case onnx::TensorProto::DataType::TensorProto_DataType_INT64: + splat_attr = mlirDenseElementsAttrFloatSplatGet( + tensorTypeFor(mlirIntegerTypeSignedGet(context_, 64)), + tensor_proto.int64_data(0)); + break; + case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE: + splat_attr = mlirDenseElementsAttrFloatSplatGet( + tensorTypeFor(mlirF64TypeGet(context_)), tensor_proto.double_data(0)); + break; + case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: + splat_attr = mlirDenseElementsAttrFloatSplatGet( + tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 64)), + tensor_proto.uint64_data(0)); + break; + case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: + // Special case: inline data is stored in uint64. + splat_attr = mlirDenseElementsAttrFloatSplatGet( + tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 32)), + tensor_proto.uint64_data(0)); + break; + } + + if (mlirAttributeIsNull(splat_attr)) { + std::string message = + "ConstantOfShape node has an unsupported splat data type: "; + message.append(tensor_proto.DebugString()); + return SetError(std::move(message)); + } + + // Create the vtensor type for the result. + MlirType splat_type = mlirAttributeGetType(splat_attr); + MlirType element_type = mlirShapedTypeGetElementType(splat_type); + MlirType vtensor_type = cc_.GetVtensorType(shape, element_type); + if (mlirTypeIsNull(vtensor_type)) + return failure(); + + MlirOperation op = createMlirOperationAtEnd( + body_block_, "torch.vtensor.literal", loc, vtensor_type, + toMlirNamedAttribute("value", splat_attr)); + MlirValue result = mlirOperationGetResult(op, 0); + + // Export to the nv_map. + auto inserted = nv_map_.insert(std::make_pair(name, result)); + if (!inserted.second) { + std::string msg = "Multiple nodes produced a value for '"; + msg.append(name); + msg.append("', most recent from "); + msg.append(node.DebugString()); + return SetError(std::move(msg)); + } + + return success(); +} + +Status NodeImporter::GetImmediateShapeTensor(const std::string &name, + std::vector &shape) { + auto found_it = graph_info_.initializer_map().find(name); + if (found_it == graph_info_.initializer_map().end()) { + std::string message = "An immediate shape value for '"; + message.append(name); + message.append("' was required but it is dynamically produced"); + return SetError(std::move(message)); + } + + const onnx::TensorProto &tp = found_it->second; + shape.clear(); + + // Since this is being interpreted as a shape, we only support some limited + // types. + size_t raw_data_size; + switch (tp.data_type()) { + case onnx::TensorProto::DataType::TensorProto_DataType_INT32: { + auto *raw_data = graph_info_.GetOptionalRawData(tp, raw_data_size); + if (raw_data) { + std::copy(raw_data, raw_data + raw_data_size, std::back_inserter(shape)); + } else { + for (auto v : tp.int32_data()) + shape.push_back(v); + } + return success(); + } + case onnx::TensorProto::DataType::TensorProto_DataType_INT64: { + auto *raw_data = graph_info_.GetOptionalRawData(tp, raw_data_size); + if (raw_data) { + std::copy(raw_data, raw_data + raw_data_size, std::back_inserter(shape)); + } else { + for (auto v : tp.int64_data()) + shape.push_back(v); + } + return success(); + } + case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: { + auto *raw_data = + graph_info_.GetOptionalRawData(tp, raw_data_size); + if (raw_data) { + std::copy(raw_data, raw_data + raw_data_size, std::back_inserter(shape)); + } else { + // Stupid special case: stored in uint64. + for (auto v : tp.uint64_data()) + shape.push_back(v); + } + return success(); + } + case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: { + auto *raw_data = + graph_info_.GetOptionalRawData(tp, raw_data_size); + if (raw_data) { + std::copy(raw_data, raw_data + raw_data_size, std::back_inserter(shape)); + } else { + for (auto v : tp.uint64_data()) + shape.push_back(v); + } + return success(); + } + } + + { + std::string message = + "An immediate shape value could not be converted from TensorProto: "; + message.append(tp.DebugString()); + return SetError(std::move(message)); + } +} + +void NodeImporter::DebugDumpModule() { + auto callback = +[](MlirStringRef sr, void *) { + fwrite(sr.data, sizeof(char), sr.length, stderr); + }; + mlirOperationPrint(module_op_, callback, nullptr); +} diff --git a/projects/onnx_c_importer/OnnxImporter.h b/projects/onnx_c_importer/OnnxImporter.h new file mode 100644 index 000000000000..57070e0e5f2a --- /dev/null +++ b/projects/onnx_c_importer/OnnxImporter.h @@ -0,0 +1,240 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +// Stand-alone ONNX -> MLIR importer. +// This library only depends on ONNX (and transitively protobuf, of course) +// and the MLIR C API. It does this to minimize its dependency surface area +// and make it possible to integrate as source code into other systems while +// retaining this implementation as the source of truth. +// +// It uses a hybrid of LLVM and Google C++ coding style, preferring the latter +// for class members/accessors because canonical protobuf coding presumes +// this kind of style. + +#include "mlir-c/IR.h" +#include "onnx/onnx_pb.h" + +#include +#include +#include + +namespace torch_mlir_onnx { + +struct Config; +class GraphInfo; +class ModelInfo; + +struct Config { + // Ancient ONNX exporters would often add a model input for anything that + // might be mutable, providing an initializer for it as well. More modern + // tools tools realized this is a really bad idea for a lot of reasons. + // We choose to assume more recent norms, even if encountering older + // models. Setting this to False probably won't do what you want but + // should produce interesting errors to waste your time deciphering. + // We mainly use it as a way to document in the code that we are + // making an assumption. + bool elide_initialized_inputs = true; +}; + +/// A light-weight status. It only encapsulates success/failure. +/// Full error information will be set on the ModelInfo. +class Status { +public: + static Status success(bool isSuccess = true) { return Status(isSuccess); } + static Status failure(bool isFailure = true) { return Status(!isFailure); } + + bool is_success() { return is_success_; } + +private: + Status(bool is_success) : is_success_(is_success) {} + bool is_success_; +}; + +static inline Status success() { return Status::success(); } +static inline Status failure() { return Status::failure(); } +static inline bool succeeded(Status status) { return status.is_success(); } +static inline bool failed(Status status) { return !status.is_success(); } + +// Accounting for a GraphProto. +class GraphInfo { +public: + GraphInfo(ModelInfo &model_info, const onnx::GraphProto &graph_proto) + : model_info_(model_info), graph_proto_(graph_proto) {} + ModelInfo &model_info() { return model_info_; } + const onnx::GraphProto &graph_proto() { return graph_proto_; } + + /// Post-construction, failable initialization. + Status Initialize(); + + /// Finds a TypeProto for the given value name. If returning nullptr, then + /// an error will have been set. + const onnx::TypeProto *FindTypeProtoForName(std::string_view name); + + /// Attempts to access the raw or external data of the TensorProto. If the + /// the data is located in those positions, returns a types pointer to it + /// and stores the number of elements to `out_size`. Otherwise, nullptr is + /// returned (and no error is set). + template + const ElementType *GetOptionalRawData(const onnx::TensorProto &tp, + size_t &out_size) { + if (tp.has_raw_data()) { + out_size = tp.raw_data().size() / sizeof(ElementType); + return reinterpret_cast(tp.raw_data().data()); + } + return nullptr; + } + + std::vector &inputs() { return inputs_; } + std::unordered_map & + input_map() { + return input_map_; + } + std::vector &outputs() { return outputs_; } + std::unordered_map & + output_map() { + return output_map_; + } + + std::unordered_map & + initializer_map() { + return initializer_map_; + } + +private: + ModelInfo &model_info_; + const onnx::GraphProto &graph_proto_; + + std::unordered_map + initializer_map_; + std::unordered_map + value_info_map_; + + std::vector declared_inputs_; + std::vector inputs_; + std::vector outputs_; + std::unordered_map input_map_; + std::unordered_map + output_map_; +}; + +/// Top-level accounting and accessors for an ONNX model. +class ModelInfo { +public: + ModelInfo(); + Config &config() { return config_; } + onnx::ModelProto &model_proto() { return model_proto_; } + + /// Post-construction, failable initialization. + Status Initialize(); + + GraphInfo &main_graph() { return *main_graph_; } + const std::string &error_message() { return error_message_; } + + Status SetError(std::string msg) { + error_message_ = std::move(msg); + return failure(); + } + + void DebugDumpProto(); + +private: + Config config_; + onnx::ModelProto model_proto_; + std::unique_ptr main_graph_; + + std::string error_message_; +}; + +class ContextCache { +public: + ContextCache(ModelInfo &model_info, MlirContext context) + : model_info_(model_info), context_(context) {} + + MlirContext context() { return context_; } + + /// Converts the TypeProto to an MlirType, returning a null type and + /// setting an error if not possible. + MlirType ConvertTypeProto(const onnx::TypeProto &tp); + + /// Converts the ONNX element type code to an MlirType, returning a null type + /// and setting an error if not possible. + MlirType ConvertTensorElementType(int element_type_code); + + /// Converts an ONNX TensorProto to an MlirAttribute, returning a null + /// attribute and setting an error if not possible. + MlirAttribute ConvertTensorProtoToAttr(const onnx::TensorProto &tp); + + /// Converts the ONNX TensorProto to an Mlir RankedTensor type. + MlirType ConvertTensorProtoToBuiltinType(const onnx::TensorProto &tp); + + /// Converts the ONNX TensorProto to a !torch.vtensor type. + MlirType ConvertTensorProtoToVtensorType(const onnx::TensorProto &tp); + + /// Gets a !torch.vtensor type for the given dims and element type. + /// Dynamic dims are represented as -1. + /// If it was not possible to create the type, sets an error and returns + /// the null type. + MlirType GetVtensorType(const std::vector &dims, + MlirType element_type); + +private: + ModelInfo &model_info_; + MlirContext context_; + + std::unordered_map elem_type_map_; + std::unordered_map asm_type_map_; + std::vector shared_dims_; +}; + +/// Imports graph nodes into a function. +class NodeImporter { +public: + NodeImporter(GraphInfo &graph_info, ContextCache &cc, + MlirOperation module_op); + + /// Called after construction to define the function in the module. Must be + /// called prior to importing nodes. + Status DefineFunction(std::optional name = {}); + + /// Imports all nodes topologically. + Status ImportAll(); + + void DebugDumpModule(); + +private: + void PopulateGraphAttrs(MlirOperation container_op); + Status ImportInitializer(const onnx::TensorProto &initializer); + Status ImportNode(const onnx::NodeProto &node); + MlirAttribute ImportGeneralAttribute(const onnx::AttributeProto &onnx_attr); + + // Special-form nodes. + Status ImportGeneralNode(const onnx::NodeProto &node); + Status ImportConstantOfShapeNode(const onnx::NodeProto &node); + + /// Looks for an initializer for `name` and attempts to treat it as a 1D + /// shape, filling `shape` if successful. Returns failure and sets an error + /// if not. + Status GetImmediateShapeTensor(const std::string &name, + std::vector &shape); + + Status SetError(std::string msg) { + return graph_info_.model_info().SetError(std::move(msg)); + } + + GraphInfo &graph_info_; + ContextCache &cc_; + MlirContext context_; + MlirOperation module_op_; + MlirOperation func_op_; + MlirBlock body_block_; + MlirLocation default_loc_; + std::unordered_map nv_map_; +}; + +} // namespace torch_mlir_onnx diff --git a/projects/onnx_c_importer/README.md b/projects/onnx_c_importer/README.md new file mode 100644 index 000000000000..571c6fd41cd8 --- /dev/null +++ b/projects/onnx_c_importer/README.md @@ -0,0 +1,7 @@ +# ONNX C Importer + +This project provides a C implementation of the `onnx_importer.py`, which is +the canonical source. It is provided as sample code for anyone who wishes to +integrate it into their system. By design, it only depends on the ONNX API +and the MLIR C API via the `mlir-c` headers. As such, it should be easy to +build into any system that already has those things by adding the sources. diff --git a/projects/onnx_c_importer/import-onnx-main.cpp b/projects/onnx_c_importer/import-onnx-main.cpp new file mode 100644 index 000000000000..58ebd98b6a70 --- /dev/null +++ b/projects/onnx_c_importer/import-onnx-main.cpp @@ -0,0 +1,103 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +// This main driver uses LLVM tool-making facilities and the support lib. +// The actual importer libraries, however, only depend on the C API so that +// they can be included in foreign projects more easily. + +#include "torch-mlir-c/Registration.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/raw_ostream.h" + +#include "OnnxImporter.h" + +#include "onnx/onnx_pb.h" + +#include +#include + +using namespace llvm; +using namespace torch_mlir_onnx; + +struct MlirState { + MlirState() { + context = mlirContextCreateWithThreading(false); + torchMlirRegisterAllDialects(context); + module = mlirModuleCreateEmpty(mlirLocationUnknownGet(context)); + } + ~MlirState() { + mlirModuleDestroy(module); + mlirContextDestroy(context); + } + + MlirContext context; + MlirModule module; +}; + +int main(int argc, char **argv) { + static cl::opt inputFilename( + cl::Positional, cl::desc(""), cl::init("-")); + + static cl::opt outputFilename("o", cl::desc("Output filename"), + cl::value_desc("filename"), + cl::init("-")); + + InitLLVM y(argc, argv); + cl::ParseCommandLineOptions(argc, argv, "torch-mlir-onnx-import-c"); + + // Open the input as an istream because that is what protobuf likes. + std::unique_ptr alloced_input_stream; + std::istream *input_stream = nullptr; + if (inputFilename == "-") { + errs() << "(parsing from stdin)\n"; + input_stream = &std::cin; + } else { + alloced_input_stream = std::make_unique( + inputFilename, std::ios::in | std::ios::binary); + if (!*alloced_input_stream) { + errs() << "error: could not open input file " << inputFilename << "\n"; + return 1; + } + input_stream = alloced_input_stream.get(); + } + + // Parse the model proto. + ModelInfo model_info; + if (!model_info.model_proto().ParseFromIstream(input_stream)) { + errs() << "Failed to parse ONNX ModelProto from " << inputFilename << "\n"; + return 2; + } + + if (failed(model_info.Initialize())) { + errs() << "error: Import failure: " << model_info.error_message() << "\n"; + model_info.DebugDumpProto(); + return 3; + } + model_info.DebugDumpProto(); + + // Import. + MlirState owned_state; + ContextCache cc(model_info, owned_state.context); + NodeImporter importer(model_info.main_graph(), cc, + mlirModuleGetOperation(owned_state.module)); + if (failed(importer.DefineFunction())) { + errs() << "error: Could not define MLIR function for graph: " + << model_info.error_message() << "\n"; + return 4; + } + if (failed(importer.ImportAll())) { + errs() << "error: Could not import one or more graph nodes: " + << model_info.error_message() << "\n"; + return 5; + } + importer.DebugDumpModule(); + + return 0; +} diff --git a/projects/pt1/e2e_testing/main.py b/projects/pt1/e2e_testing/main.py index f3ae621e466e..4c9727772d02 100644 --- a/projects/pt1/e2e_testing/main.py +++ b/projects/pt1/e2e_testing/main.py @@ -18,13 +18,16 @@ LinalgOnTensorsBackendTestConfig, StablehloBackendTestConfig, NativeTorchTestConfig, + OnnxBackendTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, TorchDynamoTestConfig, ) from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend +from torch_mlir_e2e_test.onnx_backends.linalg_on_tensors import LinalgOnTensorsOnnxBackend from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend +from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend from .xfail_sets import ( LINALG_XFAIL_SET, @@ -36,7 +39,9 @@ LTC_XFAIL_SET, LTC_CRASHING_SET, TORCHDYNAMO_XFAIL_SET, - TORCHDYNAMO_CRASHING_SET + TORCHDYNAMO_CRASHING_SET, + ONNX_CRASHING_SET, + ONNX_XFAIL_SET, ) # Import tests to register them in the global registry. @@ -44,7 +49,7 @@ register_all_tests() def _get_argparse(): - config_choices = ["native_torch", "torchscript", "linalg", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"] + config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo", "onnx"] parser = argparse.ArgumentParser(description="Run torchscript e2e tests.") parser.add_argument("-c", "--config", choices=config_choices, @@ -53,10 +58,12 @@ def _get_argparse(): Meaning of options: "linalg": run through torch-mlir"s default Linalg-on-Tensors backend. "tosa": run through torch-mlir"s default TOSA backend. +"stablehlo": run through torch-mlir"s default Stablehlo backend. "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration). "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly). "lazy_tensor_core": run the model through the Lazy Tensor Core frontend and execute the traced graph. "torchdynamo": run the model through the TorchDynamo frontend and execute the graph using Linalg-on-Tensors. +"onnx": export to the model via onnx and reimport using the torch-onnx-to-torch path. """) parser.add_argument("-f", "--filter", default=".*", help=""" Regular expression specifying which tests to include in this run. @@ -90,8 +97,11 @@ def main(): if args.config == "linalg": config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend()) xfail_set = LINALG_XFAIL_SET - # See https://discord.com/channels/636084430946959380/742573221882364009/1216676777137672235 crashing_set = set(["ConvolutionModule2DTranspose_basic"]) + elif args.config == "stablehlo": + config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend()) + xfail_set = all_test_unique_names - STABLEHLO_PASS_SET + crashing_set = STABLEHLO_CRASHING_SET elif args.config == "tosa": config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend()) xfail_set = all_test_unique_names - TOSA_PASS_SET @@ -116,6 +126,10 @@ def main(): config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend()) xfail_set = TORCHDYNAMO_XFAIL_SET crashing_set = TORCHDYNAMO_CRASHING_SET + elif args.config == "onnx": + config = OnnxBackendTestConfig(LinalgOnTensorsOnnxBackend()) + xfail_set = ONNX_XFAIL_SET + crashing_set = ONNX_CRASHING_SET do_not_attempt = set(args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or []).union(crashing_set) available_tests = [test for test in GLOBAL_TEST_REGISTRY if test.unique_name not in do_not_attempt] diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index dcefe850cc4a..05343f20c1dd 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -16,8 +16,6 @@ print(f"TORCH_VERSION_FOR_COMPARISON =", torch_version_for_comparison()) LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | { - "Conv1dNoPaddingModule_basic", - "Conv1dNoPaddingTransposeModule_basic", "Conv1dNoPaddingGroupModule_basic", "RepeatInterleaveStaticModule_basic", "RepeatInterleaveFillModule_basic", @@ -29,19 +27,28 @@ # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8 "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "IscloseStaticModule_basic", - "IscloseStaticModuleTrue_basic" + "IscloseStaticModuleTrue_basic", + "SplitWithSizes_Module_basic", } TORCHDYNAMO_XFAIL_SET = { #### General TorchDynamo/PyTorch errors + # torch._dynamo.exc.Unsupported: Tensor.item + "CumsumModule_basic", + # TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0 # RuntimeError: Failed running call_function aten.convolution_backward(... # https://github.com/pytorch/pytorch/issues/89629 "ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2D_basic", + # Size result mismatch (exposed by downstream canonicalizer + # on incompatabile casts). + # https://github.com/pytorch/pytorch/issues/119407 + "ConvolutionBackwardModule2DStrided_basic", + # RuntimeError: Index tensor must have the same number of dimensions as self tensor # RuntimeError: Failed running call_function aten.nll_loss_backward(... # https://github.com/pytorch/pytorch/issues/89630 @@ -66,14 +73,6 @@ # See also: https://github.com/pytorch/torchdynamo/issues/327 "AtenEmbeddingBagSumExample_basic", - # error: failed to legalize operation 'torch.valsem.aten.bernoulli.float' that was explicitly marked illegal - "BernoulliFloatModule_basic", - "BernoulliPModule_basic", - # error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal - "ElementwiseFlattenBroadcastModule_basic", - "FlattenRank0Module_basic", - "UniformModule_basic", - "UniformStaticShapeModule_basic", # error: unsupported by backend contract: tensor with unknown rank # note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32> "ElementwisePreluModule_basic", @@ -107,6 +106,7 @@ # START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {} 'AtenSubFloatModule_basic', + 'AtenMulFloatModule_basic', 'BoolFloatFalseModule_basic', 'BoolFloatTrueModule_basic', 'CeilFloatModule_basic', @@ -116,6 +116,7 @@ 'GtFloatIntModule_basic', 'NeFloatIntModule_basic', 'SubFloatModule_basic', + 'MulFloatModule_basic', 'TensorToFloatZeroRank_basic', 'TensorToFloat_basic', # END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {} @@ -140,6 +141,10 @@ 'ViewCollapseDynamicWithAtenSizeIntModule_basic', # END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {} + # ERROR: torch._dynamo.exc.Unsupported: Tensor.item + 'AtenItemIntOpModule_basic', + 'AtenItemFpOpModule_basic', + # ERROR: torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {'reverse': ConstantVariable(bool)} 'SortIntListReverse_basic', @@ -218,10 +223,6 @@ 'ConstantBoolParameterModule_basic', # START tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' - "AddCDivModule_basic", - "ElementwiseMulScalarModule_basic", - "ElementwiseMulScalarModule_float", - "NativeGroupNormBackwardModule_basic", "UpSampleNearest2dDynamicSize_basic", "UpSampleNearest2dStaticFactor_basic", "UpSampleNearest2dStaticSize_basic", @@ -229,22 +230,7 @@ # END tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' # START tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' - "BatchNorm1DModule_basic", - "BatchNorm1DWith2DInputModule_basic", - "BatchNorm2DModule_basic", - "BatchNorm3DModule_basic", - "BatchNorm1DStaticShapeModule_basic", "ElementwiseAddScalarFloatModule_basic", - "ElementwiseAddScalarInt64Module_basic", - "ElementwiseAddScalarIntModule_basic", - "MobilenetV3Module_basic", - "NativeBatchNorm1DModule_basic", - "NativeBatchNorm2DModule_basic", - "NativeBatchNorm3DModule_basic", - "NativeBatchNormNoneWeightModule_basic", - "NativeGroupNormModule_basic", - "ResNet18Module_basic", - "ResNet18StaticModule_basic", # END tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' # ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int' @@ -257,9 +243,6 @@ # ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int' "ElementwiseAtenDivIntScalarModule_basic", - # ERROR: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int' - "ElementwiseMulScalarModule_int", - # ERROR: 'torch.aten.sub.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' "ElementwiseSubScalarFloatModule_basic", "ElementwiseSubScalarIntModule_basic", @@ -267,6 +250,8 @@ # ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode "ElementwiseDivRoundingModeFloorModule_basic", "ElementwiseDivRoundingModeTruncModule_basic", + "AdaptiveAvgPool1dStaticLargerOutput_basic", + "AdaptiveAvgPool1dGeneralDynamic_basic", # ERROR: Exception: Unsupported op: get_attr "NumToTensorFloatModule_basic", @@ -294,16 +279,11 @@ "RepeatInterleaveFillModule_basic", # failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal - "Conv1dNoPaddingModule_basic", - "Conv1dNoPaddingTransposeModule_basic", "Conv1dNoPaddingGroupModule_basic", # tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0 "IndexPutImpl2DNoneIndexBroadcastStaticModule_basic", - # failed to legalize operation 'torch.aten.clamp' that was explicitly marked illegal - "ElementwiseClampIntModule_basic", - # failed to legalize operation 'torch.constant.int' "RepeatInterleaveStaticModule_basic", @@ -317,8 +297,7 @@ # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8 "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", - # Exception: Unsupported: node.meta['val'] is not a FakeTensor or list of FakeTensor's: _scaled_dot_product_flash_attention; - "ScaledDotProductAttentionSameModule_basic", + # AssertionError: Unregistered operation: torch.aten._scaled_dot_product_flash_attention_for_cpu "ScaledDotProductAttentionDifferentModule_basic", # AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only @@ -340,6 +319,23 @@ # ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4])) "ArangeStartOutViewModule_basic", + + # Dynamo does not support tracing quantized tensors + "ElementwiseDequantizePerChannelModule_basic", + "ElementwiseDequantizePerTensorModule_basic", + "ElementwiseQuantizePerTensorModule_basic", + "AtenMmQuint8_basic", + "Conv2dQInt8Module_basic", + + # Dynamo not supporting conv_tbc + "ConvTbcModule_basic", + + "FloatImplicitModule_basic", + "IntImplicitModule_basic", + + # Others + "GridSamplerBasic1_basic", + "GridSamplerBasic2_basic", } if torch_version_for_comparison() <= version.parse("2.2.0"): @@ -381,87 +377,30 @@ "IndexPutImpl2DNoneIndexStaticModule_basic", # See https://discord.com/channels/636084430946959380/742573221882364009/1216676777137672235 "ConvolutionModule2DTranspose_basic", + + "MaxPool3dCeilModeTrueModule_basic", + "MaxPool3dEmptyStrideStaticModule_basic", + "MaxPool3dLargeDatadModule_basic", + "MaxPool3dModuleRandomSimple_basic", + "MaxPool3dModule_basic", + "MaxPool3dStaticCeilModeTrueModule_basic", + "MaxPool3dStaticModule_basic", + + # Looks like incorrect fx graph conversion + "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", } STABLEHLO_PASS_SET = { - "TileBigDimsSizeModule_basic", - "TileSmallDimsSizeModule_basic", + "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "AddIntModule_basic", - "AtenIntBoolOpModule_basic", - "AtenIntTensorByteDtypeModule_basic", - "AtenIntTensorCharDtypeModule_basic", - "BoolFloatFalseModule_basic", - "BoolFloatTrueModule_basic", - "BoolIntFalseModule_basic", - "BoolIntTrueModule_basic", - "CeilFloatModule_basic", - "DivFloatModule_basic", - "DivIntModule_basic", - "EqIntModule_basic", - "GeFloatIntModule_basic", - "GeFloatModule_basic", - "GeIntModule_basic", - "GtFloatIntModule_basic", - "GtIntModule_basic", - "MulIntModule_basic", - "NeFloatIntModule_basic", - "NeIntModule_basic", - "SqrtIntModule_basic", - "SubFloatModule_basic", - "SubIntModule_basic", - "TensorToBoolZeroRank_basic", - "TensorToIntZeroRank_basic", - "TensorToFloatZeroRank_basic", - "IndexTensorStaticContiguousWithNoneModule_basic", - "IndexTensorStaticNonContiguousWithNoneModule_basic", "AliasModule_basic", - "TensorIntModule_basic", "AllBoolFalseModule_basic", "AllBoolTrueModule_basic", "AnyBoolFalseModule_basic", "AnyBoolTrueModule_basic", - "AtenIntBoolOpConstFalseModule_basic", - "AtenIntBoolOpConstTrueModule_basic", - "AtenFloatScalarModule_basic", - "ScalarImplicitFloatModule_basic", - "ScalarImplicitIntModule_basic", - "AtenSubFloatModule_basic", - "BoolFloatConstantModule_basic", - "BoolIntConstantModule_basic", - "ContainsIntList_False", - "ContainsIntList_True", - "IntFloatModule_basic", - "IsFloatingPointFloat_True", - "IsFloatingPointInt_False", - "LenStrModule_basic", - "MeanDimAllReduceKeepdimModule_basic", - "MeanDimAllReduceModule_basic", - "MeanDimDtypeModule_basic", - "MeanDimKeepdimModule_basic", - "MeanDimModule_basic", - "MeanDimNegativeModule_basic", - "NumelZeroRankModule_basic", - "PowIntFloatModule_basic", - "PrimMaxIntModule_basic", - "PrimMinIntModule_basic", - "PrimMinIntDynamicModule_basic", - "SortIntListReverse_basic", - "SortIntList_basic", - "SqrtIntConstantModule_basic", - "StdBiasedModule_basic", - "StdDimBiasedModule_basic", - "TestMultipleTensorAndPrimitiveTypesReturn_basic", - "VarBiasedModule_basic", - "VarDimBiasedModule_basic", - "VarMeanBiasedModule_basic", - "VarMeanDimBiasedModule_basic", - "ConstantBoolParameterModule_basic", - "MaskedFillScalarIntValueStaticModule_basic", - "MaskedFillScalarFloatValueStaticModule_basic", - "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", - "AddSizeIntModule_basic", - "AddSizeIntNegDimModule_basic", "ArangeDtypeFloatModule_basic", "ArangeDtypeIntModule_basic", "ArangeFalsePinMemoryModule_basic", @@ -473,143 +412,176 @@ "ArangeStartIntModule_basic", "ArangeStartNegativeStepFloatModule_basic", "ArangeStartNegativeStepIntModule_basic", + "ArangeStartOutDtypeModule_basic", + "ArangeStartOutModule_basic", + "ArangeStartOutViewModule_basic", "ArangeStartStepFloatModule_basic", "ArangeStartStepIntModule_basic", "ArangeZeroElementOutputModule_basic", - "BatchMlpLayerModule_basic", - "BatchNorm1DModule_basic", - "BatchNorm1DWith2DInputModule_basic", - "BatchNorm2DModule_basic", - "BatchNorm3DModule_basic", - "BatchNorm1DStaticShapeModule_basic", - "ResNet18StaticModule_basic", + "ArgmaxModule_with_dim", + "AtenComplex64Module_basic", + "AtenEyeMModuleCPUDevice_basic", + "AtenEyeMModuleDefaultDtype_basic", + "AtenEyeMModuleFalsePinMemory_basic", + "AtenEyeMModuleFloat2D_basic", + "AtenEyeMModuleInt2D_basic", + "AtenEyeModuleCPUDevice_basic", + "AtenEyeModuleDefaultDtype_basic", + "AtenEyeModuleFalsePinMemory_basic", + "AtenEyeModuleFloat2D_basic", + "AtenEyeModuleInt2D_basic", + "AtenFloatScalarModule_basic", + "AtenInstanceNormModule_basic", + "AtenIntBoolOpConstFalseModule_basic", + "AtenIntBoolOpConstTrueModule_basic", + "AtenIntBoolOpModule_basic", + "AtenIntTensorByteDtypeModule_basic", + "AtenIntTensorCharDtypeModule_basic", + "AtenItemFpOpModule_basic", + "AtenItemIntOpModule_basic", + "AtenMmFloatTypes_basic", + "AtenMmIntTypes_basic", + "AtenRoundFloatHalfToEvenModule_basic", + "AtenRoundFloatModule_basic", + "AtenRoundIntModule_basic", + "AtenSubFloatModule_basic", + "AtenToDeviceModule_basic", "AtenToDtypeModule_basic", - "BmmFloatModule_basic", - "BmmIntModule_basic", - "BroadcastToModule_basic", + "AvgPool1dStaticModule_basic", + "AvgPool2dStaticModule_basic", + "BaddbmmBroadcast1DInputModule_basic", + "BaddbmmBroadcast2DInputModule_basic", + "BaddbmmStaticModule_basic", + "BoolFloatConstantModule_basic", + "BoolFloatFalseModule_basic", + "BoolFloatTrueModule_basic", + "BoolIntConstantModule_basic", + "BoolIntFalseModule_basic", + "BoolIntTrueModule_basic", + "BoolTensorReturnFalseModule_basic", + "BoolTensorReturnMixedModule_basic", + "BoolTensorReturnTrueModule_basic", + "BroadcastListConstructWithMinusOneModule_basic", "BroadcastToSameRankStaticModule_basic", "BroadcastToDifferentRankStaticModule_basic", "BroadcastZeroRankInputStaticModule_basic", - "BroadcastListConstructWithMinusOneModule_basic", "BroadcastDifferentRankSameFinalShapeModule_basic", "BroadcastDifferentRankWithMinusOneModule_basic", "BroadcastToDifferentRankNotOneStaticModule_basic", "BucketizeTensorStaticFloatModule_basic", "BucketizeTensorStaticModule_basic", + "CeilFloatModule_basic", + "ChunkListUnpackUneven_Module_basic", + "ChunkListUnpack_Module_basic", + "CloneModule_basic", + "ConstantBoolParameterModule_basic", + "ContainsIntList_False", + "ContainsIntList_True", + "ContiguousModule_basic", + "Conv1dNoPaddingGroupModule_basic", + "Conv1dNoPaddingModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", + "Conv2dWithPaddingDilationStrideStaticModule_grouped", + "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", + "Convolution2DStaticModule_basic", + "Convolution2DGroupsStatic_basic", + "ConvolutionBackwardModule2DStatic_basic", + "ConvolutionBackwardModule2DStrided_basic", + "ConvolutionModule2DTransposeStridedStatic_basic", + "CosineSimilarityStaticBroadcastModule_basic", + "CosineSimilarityStaticModule_basic", + "CumsumInputDtypeInt32Module_basic", "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", - "CosineSimilarityStaticModule_basic", - "CosineSimilarityStaticBroadcastModule_basic", "DetachModule_basic", - "ElementwiseIsnanModule_basic", + "DivFloatModule_basic", + "DivIntModule_basic", + "DropoutEvalFloatModule_basic", + "DropoutEvalIntModule_basic", + "DropoutTrainStaticShapeModule_basic", + "EinsumStaticContractRhsModule_basic", + "EinsumStaticFourDimensionModule_basic", + "EinsumStaticModule_basic", + "ElementwiseAbsFloatModule_basic", + "ElementwiseAbsIntModule_basic", + "ElementwiseAddScalar_NumToTensorFloat_Module_basic", + "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", + "ElementwiseAtenIsinfOpModule_basic", + "ElementwiseAtenIsneginfOpModule_basic", + "ElementwiseAtenIsposinfOpModule_basic", "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseAtenLogicalNotOpModule_basic", "ElementwiseAtenLogicalNotOpPromoteModule_basic", "ElementwiseAtenLogicalOrOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseAtenWhereSelfModule_basic", - "ElementwiseWhereScalarOtherStaticModule_basic", - "ElementwiseWhereScalarSelfStaticModule_basic", + "ElementwiseBinaryStaticShapeModule_basic", "ElementwiseBitwiseAndStaticShapeModule_basic", - "ElementwiseBitwiseNotInt64Module_basic", "ElementwiseBitwiseNotInt32Module_basic", - "ElementwiseOrTensorStaticShapeModule_basic", + "ElementwiseBitwiseNotInt64Module_basic", "ElementwiseBitwiseOrStaticShapeModule_basic", "ElementwiseBitwiseXorStaticShapeModule_basic", - "ElementwiseClampModule_basic", - "ElementwiseClampMinModule_basic", - "ElementwiseClampMaxModule_basic", + "ElementwiseCeilModule_basic", "ElementwiseClampIntModule_basic", - "ElementwiseSignModule_basic", - "ElementwisePowModule_basic", - "ElementwisePowTensorStaticModule_basic", - "ElementwisePowTensorBroadcastStaticModule_basic", + "ElementwiseClampMaxModule_basic", + "ElementwiseClampMinModule_basic", + "ElementwiseClampModule_basic", + "ElementwiseClampTensorInt8Module_basic", + "ElementwiseCloneChannelsLastMemoryFormatModule_basic", + "ElementwiseCloneContiguousModule_basic", + "ElementwiseCloneModule_basic", + "ElementwiseCosModule_basic", + "ElementwiseErfModule_basic", "ElementwiseExpModule_basic", - "ElementwiseFlattenBroadcastModule_basic", - "ElementwiseLeakyReluModule_basic", - "ElementwiseEluModule_basic", - "ElementwiseEluNonDefaultModule_basic", + "ElementwiseFloorIntModule_basic", + "ElementwiseFloorModule_basic", + "ElementwiseGeluModule_basic", + "ElementwiseGeluApproximateTanhModule_basic", + "ElementwiseLeakyReluStaticModule_basic", "ElementwiseLogModule_basic", + "ElementwiseNanToNumModule_Basic", + "ElementwiseNeFloatTensorStaticModule_basic", + "ElementwiseNeIntTensorStaticModule_basic", "ElementwiseNegModule_basic", + "ElementwiseOrTensorStaticShapeModule_basic", + "ElementwisePowTensorBroadcastStaticModule_basic", + "ElementwisePowTensorStaticModule_basic", + "ElementwiseReciprocalModule_basic", + "ElementwiseReluModule_basic", "ElementwiseRsqrtModule_basic", "ElementwiseSigmoidModule_basic", - "ElementwiseSqrtModule_basic", "ElementwiseSinModule_basic", - "ElementwiseCosModule_basic", - "ElementwiseCeilModule_basic", - "ElementwiseFloorModule_basic", - "ElementwiseUnaryModule_basic", - "ElementwiseUnsqueezeBroadcastModule_basic", - "ElementwiseUnsqueezeNegDimsModule_basic", + "ElementwiseSqrtModule_basic", "ElementwiseToDtypeF32ToI64Module_basic", - "ElementwiseAddModule_basic", - "ElementwiseAddScalarFloatModule_basic", - "ElementwiseAddScalarInt64Module_basic", - "ElementwiseAddScalarIntModule_basic", - "ElementwiseAddScalar_NumToTensorFloat_Module_basic", - "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", - "ElementwiseDivScalarModule_basic", - "ElementwiseAtenDivIntScalarModule_basic", - "ElementwiseEqDiffWidthScalarModule_basic", - "ElementwiseEqFloatScalarModule_basic", - "ElementwiseEqIntScalarModule_basic", - "ElementwiseEqBoolScalarModule_basic", - "ElementwiseNeFloatScalarModule_basic", - "ElementwiseNeFloatTensorStaticModule_basic", - "ElementwiseNeIntTensorStaticModule_basic", - "ElementwiseEqBoolScalarModule_basic", - "ElementwiseErfModule_basic", - "ElementwiseGeluModule_basic", - "ElementwiseGtFloatScalarModule_basic", - "ElementwiseGtIntScalarModule_basic", - "ElementwiseGtMixed2ScalarModule_basic", - "ElementwiseGeFloatIntScalarModule_basic", - "ElementwiseGeFloatScalarModule_basic", - "ElementwiseGeIntScalarModule_basic", - "ElementwiseGeMixedIntScalarModule_basic", - "ElementwiseLeakyReluStaticModule_basic", - "ElementwiseLeFloatIntScalarModule_basic", - "ElementwiseLeFloatScalarModule_basic", - "ElementwiseLeIntScalarModule_basic", - "ElementwiseLeMixedIntScalarModule_basic", - "ElementwiseLtDiffWidthScalarModule_basic", - "ElementwiseLtFloatScalarModule_basic", - "ElementwiseLtIntScalarModule_basic", - "ElementwiseMulScalarModule_basic", - "ElementwiseMulScalarModule_float", - "ElementwiseMulScalarModule_int", - "ElementwiseNeIntScalarModule_basic", - "ElementwiseReciprocalModule_basic", - "ElementwiseRelu6Module_basic", - "ElementwiseReluModule_basic", - "ElementwiseRemainderScalarModule_Bool_basic", - "ElementwiseRemainderScalarModule_Float_basic", - "ElementwiseRemainderScalarModule_Int_Float_basic", - "ElementwiseRemainderScalarModule_Int_basic", - "ElementwiseSubScalarFloatModule_basic", - "ElementwiseSubScalarIntModule_basic", - "ElementwiseWhereScalarModule_basic", - "ElementwiseAbsModule_basic", - "EmbeddingModule1DIndices_basic", - "EmbeddingModuleI32Static_basic", - "EmbeddingModuleI32_basic", - "EmbeddingModuleI64_basic", - "EmbeddingModuleF16_basic", + "ElementwiseToDtypeI64ToI8Module_basic", + "ElementwiseToDtypeIdentityModule_basic", + "ElementwiseUnaryModule_basic", + "ElementwiseWhereScalarOtherStaticModule_basic", + "ElementwiseWhereScalarSelfStaticModule_basic", + "EmptyModule_uint8", "EmptyLikeMemoryFormatModule_basic", "EmptyLikeModule_defaultDtype", "EmptyLikeModule_falsePinMemory", "EmptyLikeModule_float", "EmptyLikeModule_int", + "EmptyModule_contiguous", + "EmptyModule_defaultDtype", + "EmptyModule_falsePinMemory", + "EmptyModule_float", + "EmptyModule_int", + "EmptyStridedModule_basic", + "EyeStaticModule_basic", + "EqIntModule_basic", "ExpandAsIntModule_basic", - "ExpandModule_basic", - "EinsumStaticModule_basic", - "EinsumStaticFourDimensionModule_basic", - "EinsumStaticContractRhsModule_basic", + "Fill_TensorFloat64WithFloat32Static_basic", "Fill_TensorFloat64WithFloat32_basic", "Fill_TensorFloat64WithFloat64_basic", - "Fill_TensorFloat64WithInt64_basic", - "Fill_TensorFloat64WithFloat32Static_basic", "Fill_TensorFloat64WithInt64Static_basic", + "Fill_TensorFloat64WithInt64_basic", + "FlattenRank0Module_basic", + "FlattenStaticModule_basic", "FlipModuleStaticShape_basic", "FlipNegativeIndexModule_basic", "FullLikeModuleDefaultDtype_basic", @@ -626,197 +598,69 @@ "FullModuleFloat3D_basic", "FullModuleInt2D_basic", "FullModuleInt3D_basic", - "NewFullModuleDefaultDtype_basic", - "NewFullModuleFalsePinMemory_basic", - "NewFullModuleFloat2D_basic", - "NewFullModuleFloat3DStatic_basic", - "NewFullModuleFloat3D_basic", - "NewFullModuleInt2DStatic_basic", - "NewFullModuleInt2D_basic", - "NewFullModuleInt3D_basic", "GatherStaticModule_basic", - "GatherModule_basic", - "Gather2DInputModdule_basic", - "GatherRandomIndexModule_basic", - "GatherNegativeDimModule_basic", + "GeFloatIntModule_basic", + "GeFloatModule_basic", + "GeIntModule_basic", "GeluBackwardModule_basic", - "HardswishModule_basic", - "HardswishRandomModule_basic", - "HardTanhIntModule_basic", - "HardTanhModule_basic", - "HardsigmoidModule_basic", - "HardsigmoidRandomModule_basic", - "IndexSelectDynamicIndexSizeModule_basic", - "IndexSelectSingleIdxModule_basic", - "IndexSelectTwoIdxModule_basic", - "IndexSelectWholeDimensionModule_basic", - "IndexSelectWholeTensorModule_basic", - "IndexSelectNegativeDimModule_basic", - "IndexSelectStaticModule_basic", - "IndexTensorStaticModule_basic", + "GluStaticModule_basic", + "GroupNormModule_basic", + "GroupNormNoWeightAndBiasModule_basic", + "GtFloatIntModule_basic", + "GtIntModule_basic", "IndexTensorModule3dInputStatic_basic", "IndexTensorMultiIndexStaticModule_basic", + "IndexTensorStaticContiguousWithNoneModule_basic", + "IndexTensorStaticModule_basic", + "IndexTensorStaticNonContiguousWithNoneModule_basic", + "IntFloatModule_basic", + "IsFloatingPointFloat_True", + "IsFloatingPointInt_False", "LayerNormLastDimModule_basic", "LayerNormModule_basic", "LayerNormNormalizeOverAllDimsModule_basic", "LeakyReluBackwardStaticModule_basic", - "LinalgVectorNormModule_basic", - "LinalgVectorNormKeepDimModule_basic", - "MatmulBroadcastBatchDim_basic", - "MatmulSingleDynamicBatchDim_basic", - "Matmul_3d", - "Matmul_4d", - "MeanDimEmptyDimModule_basic", - "MeanDtypeModule_basic", - "MeanDynamicSizesModule_basic", - "MeanLargeInputModule_basic", - "MeanModule_basic", - "Mlp1LayerModule_basic", - "Mlp2LayerModule_basic", - "MmTanhModule_basic", - "Mv_basic", - "NativeLayerNormModule4D_basic", - "NativeLayerNormModule_basic", - "OneHotModule_basic", - "PrimsConvertElementTypeModule_basic", - "ReduceFrobeniusNormKeepDimModule_basic", - "ReduceSumDimIntListElementTypeBoolModule_basic", - "ReduceSumElementTypeBoolModule_basic", - "ReduceSumDimIntListEmptyDimModule_basic", - "ReduceSumDimIntListDtypeFloatModule_basic", - "ReduceSumDimIntListDtypeIntModule_basic", - "ReduceSumDimIntListKeepDimFloatModule_basic", - "ReduceSumDimIntListKeepDimIntModule_basic", - "ReduceSumDtypeFloatModule_basic", - "ReduceSumDtypeIntModule_basic", - "ReduceL1NormModule_basic", - "ReduceL1NormWithDTypeModule_basic", - "ReduceL2NormModule_basic", - "ReduceL3NormAllDimsModule_basic", - "ReduceL3NormKeepDimModule_basic", - "ReduceLN3NormModule_basic", - "NormScalarOptDimKeepDimModule_basic", - "NormScalarOptDimModule_basic", - "NormalizeModule_basic", - "ScalarConstantTupleModule_basic", - "SelectIntModule_basic", - "SelectIntNegativeDimAndIndexStaticModule_basic", - "SliceSingleIdxModule_basic", - "SqueezeDimModule_dynamic", - "SqueezeDimModule_negDim", - "ToCopyBoolDTypeStaticModule_basic", - "ToCopyModule_basic", - "ToCopyWithDTypeFalsePinMemoryModule_basic", - "ToCopyWithDTypeModule_basic", - "ReduceFrobeniusNormModule_basic", - "FlattenStaticModule_basic", - "FlattenRank0Module_basic", - "TensorsConcatNegativeDimModule_basic", - "TensorsConcatPromoteDTypeModule_basic", - "TensorsConcatStaticModule_basic", - "TensorsConcatNegativeDimStaticModule_basic", - "TensorsConcatPromoteDTypeStaticModule_basic", - "TensorsStackModule_basic", - "TensorsStackNegativeDimModule_basic", - "TensorsStackPromoteDTypeModule_basic", + "LenStrModule_basic", "LiftFreshCopyModule_basic", - "Mlp2LayerModuleNoBias_basic", - "NumelModule_basic", - "SiluModule_basic", - "SquareModule_basic", - "SqueezeModule_allUnitDim", - "SqueezeDimModule_unitDim", - "ViewCollapseOnesMiddleModule_basic", - "ViewDoubleMergeStaticModule_basic", - "ViewExpandDynamicDimModule_basic", - "ViewFlattenAndExpandModule_basic", - "ViewFiveTestStaticModule_basic", - "ViewOffsetTestStaticModule_basic", - "ViewTwoFiveThreeStaticModule_basic", - "ViewTwoToThreeStaticModule_basic", - "ViewExpandOnesMiddleOppModule_basic", - "ViewOffsetBackwardTestStaticModule_basic", - "NumToTensorFloatModule_basic", - "AtenToDeviceModule_basic", - "AvgPool1dStaticModule_basic", - "AvgPool2dStaticModule_basic", - "Conv1dNoPaddingModule_basic", - "Conv1dNoPaddingGroupModule_basic", - "Conv2dWithPaddingDilationStrideStaticModule_basic", - "Conv2dWithPaddingDilationStrideStaticModule_depthwise", - "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", - "Conv2dWithPaddingDilationStrideStaticModule_grouped", - "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", - "Convolution2DStaticModule_basic", - "ConvolutionModule2DTransposeStridedStatic_basic", - "Convolution2DGroupsStatic_basic", - "ElementwiseCloneContiguousModule_basic", - "ElementwiseCloneChannelsLastMemoryFormatModule_basic", - "ElementwiseCloneModule_basic", - "ElementwiseBinaryStaticShapeModule_basic", - "ReturnThreeTensorFloat32_basic", - "BoolTensorReturnFalseModule_basic", - "BoolTensorReturnTrueModule_basic", - "BoolTensorReturnMixedModule_basic", - "SqueezeModule_static", - "TModuleRank1_basic", - "TModuleRank0_basic", - "ElementwiseToDtypeIdentityModule_basic", - "View1DFoldModule_basic", - "UnsafeView1DFoldModule_basic", - "UnflattenStaticModule_basic", - "UnflattenIntStaticModule_basic", - "UnflattenIntNegativeOneDimStaticModule_basic", - "UnflattenIntNegativeOneSizeStaticModule_basic", - "RsubFloatModule_basic", - "RsubFloatModule_noalpha_basic", - "RsubIntModule_basic", - "RsubIntModule_noalpha_basic", - "RsubIntStaticModule_noalpha_basic", - "RsubInt0d_NumToTensor_Module_basic", - "ScalarTensorDefaultDtypeModule_basic", - "ScalarTensorFloat32Module_basic", - "ScalarTensorInt32Module_basic", - "ScalarTensorInt64Module_basic", - "SelectScattertModule_basic", - "SelectScattertStaticModule_basic", - "SliceStaticModule_basic", - "SliceModule_basic", - "SliceNegIdxModule_basic", - "SliceOutOfLowerBoundStartIndexModule_basic", - "SliceOutOfLowerBoundStartIndexStaticModule_basic", - "SliceOutOfUpperBoundIndexModule_basic", - "SliceOutOfUpperBoundIndexStaticModule_basic", - "SliceStartEqEndModule_basic", - "SliceSizeTwoStepModule_basic", - "SliceSizeTwoStepDivisibleStaticModule_basic", - "SliceWholeTensorModule_basic", - "SliceScatterModule_basic", - "SliceScatterNegativeDimModule_basic", - "SliceScatterNegativeEndModule_basic", - "SliceScatterStaticModule_basic", - "SliceEndSleStartStaticModule_basic", - "SliceScatterStepVariationModule_basic", - "SliceScatterZeroDimModule_basic", - "SqueezeDimModule_static", - "SqueezeDimModule_identity", - "SqueezeModule_broadcast", - "ReturnTwoTensorF32I64_basic", + "MaskedFillScalarFloatValueStaticModule_basic", + "MaskedFillScalarIntValueStaticModule_basic", "Matmul4dStatic_basic", - "Matmul_dot", + "Matmul4dStaticBroadcast_basic", "Matmul_2d", + "Matmul_dot", "Matmul_matvec", "Matmul_vecmat", + "MaxPool2dStaticModule_basic", "MaxPool2dWithIndicesStaticModule_basic", + "MeanDimAllReduceKeepdimModule_basic", + "MeanDimAllReduceModule_basic", + "MeanDimEmptyDimModule_basic", + "MeanDtypeModule_basic", + "MeanDynamicSizesModule_basic", + "MeanModule_basic", + "Mlp2LayerModuleNoBias_basic", "MmDagModule_basic", "MmModule_basic", "MmModule_chained", - "MaxPool2dStaticModule_basic", - "EmptyModule_contiguous", - "EmptyModule_defaultDtype", - "EmptyModule_falsePinMemory", - "EmptyModule_int", - "EmptyModule_float", + "MmTanhModule_basic", + "MoveDimIntModule_basic", + "MoveDimIntNegativeIndexModule_basic", + "MulFloatModule_basic", + "MulIntModule_basic", + "Mv_basic", + "NarrowHorizontalTest2_basic", + "NarrowHorizontalTest_basic", + "NarrowTensorHorizontalModule_basic", + "NarrowTensorVerticalModule_basic", + "NarrowVerticalTest2_basic", + "NarrowVerticalTest_basic", + "NativeDropoutEvalFloatModule_basic", + "NativeDropoutTrainStaticShapeModule_basic", + "NativeGroupNormModule_basic", + "NativeLayerNormModule4D_basic", + "NativeLayerNormModule_basic", + "NeFloatIntModule_basic", + "NeIntModule_basic", "NewEmptyModuleBool_basic", "NewEmptyModuleDefaultDtype_basic", "NewEmptyModuleFalsePinMemory_basic", @@ -828,120 +672,177 @@ "NewEmptyModuleNonDefaultFloatDtype_basic", "NewEmptyModuleNonDefaultIntDtype_basic", "NewEmptyStridedModuleDefaultDtype_basic", - "EmptyStridedModule_basic", - "EmptyStridedSizeIntStrideModule_basic", - "PermuteModule_basic", - "PermuteNegativeIndexModule_basic", - "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", - "ZeroFloat32Module_basic", - "ZeroInt32Module_basic", - "ZeroInt64Module_basic", - "ZerosLikeModule_defaultDtype", - "ZerosLikeModule_falsePinMemory", - "ZerosLikeModule_float", - "ZerosLikeModule_int", - "ZerosModuleDefaultDtype_basic", - "ZerosModuleInt2D_basic", - "ZerosModuleInt3D_basic", - "ZerosModuleFloat2D_basic", - "ZerosModuleFloat3D_basic", - "ZerosModuleFalsePinMemory_basic", - "OnesModuleDefaultDtype_basic", - "OnesModuleInt_basic", - "OnesModuleFloat_basic", - "OnesModuleFalsePinMemory_basic", - "OnesLikeModule_defaultDtype", - "OnesLikeModule_falsePinMemory", - "OnesLikeModule_float", - "OnesLikeModule_int", - "NewZerosModuleDefaultDtype_basic", - "NewZerosModuleInt2D_basic", - "NewZerosModuleInt3D_basic", - "NewZerosModuleFloat2D_basic", - "NewZerosModuleFloat3D_basic", - "NewZerosModuleFalsePinMemory_basic", + "NewFullModuleDefaultDtype_basic", + "NewFullModuleFalsePinMemory_basic", + "NewFullModuleFloat3DStatic_basic", + "NewFullModuleFloat3D_basic", + "NewFullModuleInt2D_basic", + "NewFullModuleInt3D_basic", "NewOnesModuleDefaultDtype_basic", - "NewOnesModuleInt2D_basic", - "NewOnesModuleInt3D_basic", + "NewOnesModuleFalsePinMemory_basic", "NewOnesModuleFloat2D_basic", "NewOnesModuleFloat3D_basic", - "NewOnesModuleFalsePinMemory_basic", + "NewOnesModuleInt2D_basic", + "NewOnesModuleInt3D_basic", + "NewZerosModuleDefaultDtype_basic", + "NewZerosModuleFalsePinMemory_basic", + "NewZerosModuleFloat2D_basic", + "NewZerosModuleFloat3D_basic", + "NewZerosModuleInt2D_basic", + "NewZerosModuleInt3D_basic", "NewZerosStaticModuleLayoutStrided_basic", - "DropoutEvalIntModule_basic", - "DropoutEvalFloatModule_basic", - "DropoutTrainStaticShapeModule_basic", - "NativeDropoutEvalFloatModule_basic", - "NativeDropoutTrainStaticShapeModule_basic", - "ContiguousModule_basic", - "DropoutModule_basic", - "ViewCollapseModule_basic", - "ViewCollapseInferredDimModule_basic", - "ViewDynamicExpandCollapseModule_basic", - "ViewDynamicExpandModule_basic", - "ViewExpandModule_basic", - "ViewExpandOnesModule_basic", - "ViewExpandOnesBeforeAndAfterModule_basic", - "ViewExpandOnesMiddleModule_basic", - "ViewExpandCollapseModule_basic", - "ViewExpandCollapseWithOnesModule_basic", - "ViewExpandInferredDimModule_basic", - "ViewNegativeStaticModule_basic", - "ViewNoChangeStaticModule_basic", - "ViewNoChange1dModule_basic", - "ViewNoChange2dModule_basic", - "ViewNoChange3dModule_basic", - "UnsafeViewExpandModule_basic", + "NormalizeModule_basic", + "NumToTensorFloatModule_basic", + "NumToTensorIntModule_basic", + "NumelModule_basic", + "NumelZeroRankModule_basic", + "NumpyTRank0Module_basic", + "NumpyTRank1Module_basic", + "NumpyTRank2Module_basic", + "NumpyTRankNDynamicModule_basic", + "NumpyTRankNStaticModule_basic", + "OnesLikeModule_defaultDtype", + "OnesLikeModule_falsePinMemory", + "OnesLikeModule_float", + "OnesLikeModule_int", + "OnesModuleCPUDevice_basic", + "OnesModuleDefaultDtype_basic", + "OnesModuleFalsePinMemory_basic", + "OnesModuleFloat_basic", + "OnesModuleInt_basic", + "Permute0RankModule_basic", + "PermuteModule_basic", + "PermuteNegativeIndexModule_basic", + "PowIntFloatModule_basic", + "PrimMaxIntModule_basic", + "PrimMinIntDynamicModule_basic", + "PrimMinIntModule_basic", + "PrimsConvertElementTypeModule_basic", + "PrimsSumFloatModule_basic", + "PrimsSqueezeEmptyDimensionsModule_basic", + "PrimsSqueezeModule_basic", + "PrimsViewOfModule_basic", + "PrimsViewOfZeroRankModule_basic", + "RandIntDtypeModule_basic", + "RandIntLowDtypeModule_basic", + "RandIntLowModule_basic", + "RandIntModule_basic", + "RandIntPinMemoryModule_basic", + "RandModule_basic", + "ReduceAmaxMultiDim_basic", + "ReduceAmaxOutOfOrderDim_basic", + "ReduceAmaxSingleDim_basic", + "ReduceFrobeniusNormModule_basic", "ReduceMaxAllDims_basic", + "ReduceMaxAlongDimNegative_basic", + "ReduceMaxAlongDimSignedInt_basic", + "ReduceMaxAlongDim_basic", "ReduceMaxFloatModule_basic", "ReduceMaxSignedIntModule_basic", "ReduceMaxUnsignedIntModule_basic", - "PrimsSumFloatModule_basic", "ReduceMinFloatModule_basic", "ReduceMinSignedIntModule_basic", "ReduceMinUnsignedIntModule_basic", + "ReduceSumDimIntListDtypeFloatModule_basic", + "ReduceSumDimIntListDtypeIntModule_basic", + "ReduceSumDimIntListElementTypeBoolModule_basic", + "ReduceSumDimIntListEmptyDimModule_basic", "ReduceSumDimIntListFloatModule_basic", "ReduceSumDimIntListIntModule_basic", + "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", + "ReduceSumDtypeFloatModule_basic", + "ReduceSumDtypeIntModule_basic", + "ReduceSumElementTypeBoolModule_basic", "ReduceSumFloatModule_basic", "ReduceSumSignedIntModule_basic", "ReduceSumUnsignedIntModule_basic", "RepeatModule_basic", "ReshapeAliasCollapseModule_basic", "ReshapeAliasExpandModule_basic", - "ReshapeExpandModule_basic", "ReshapeAsModule_basic", - "TestMultipleTensorReturn_basic", - "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", - "BaddbmmStaticModule_basic", - "BaddbmmBroadcast1DInputModule_basic", - "BaddbmmBroadcast2DInputModule_basic", - "NarrowHorizontalTest2_basic", - "NarrowHorizontalTest_basic", - "NarrowVerticalTest2_basic", - "NarrowVerticalTest_basic", - "NarrowTensorHorizontalModule_basic", - "NarrowTensorVerticalModule_basic", - "NumToTensorIntModule_basic", - "NumpyTRank0Module_basic", - "NumpyTRank1Module_basic", - "NumpyTRank2Module_basic", - "NumpyTRankNStaticModule_basic", - "NumpyTRankNDynamicModule_basic", - "TensorsSplitTensorModule_basic", - "TensorsSplitTensorNegativeDimModule_basic", - "TensorsSplitTensorLastSmallerModule_basic", + "ReshapeExpandModule_basic", + "ReturnThreeTensorFloat32_basic", + "ReturnTwoTensorF32I64_basic", + "RollModule_basic", + "RsubInt0d_NumToTensor_Module_basic", + "ScalarConstantTupleModule_basic", + "ScalarImplicitFloatModule_basic", + "ScalarImplicitIntModule_basic", + "ScalarTensorDefaultDtypeModule_basic", + "ScalarTensorFloat32Module_basic", + "ScalarTensorInt32Module_basic", + "ScalarTensorInt64Module_basic", + "SelectIntNegativeDimAndIndexStaticModule_basic", + "SelectScattertStaticModule_basic", + "SliceModule_basic", + "SliceNegIdxModule_basic", + "SliceOutOfLowerBoundStartIndexModule_basic", + "SliceOutOfUpperBoundIndexModule_basic", + "SliceOutOfUpperBoundIndexStaticModule_basic", + "SliceScatterModule_basic", + "SliceScatterNegativeDimModule_basic", + "SliceScatterNegativeEndModule_basic", + "SliceScatterStaticModule_basic", + "SliceScatterStepVariationModule_basic", + "SliceScatterZeroDimModule_basic", + "SliceEndSleStartStaticModule_basic", + "SliceOutOfLowerBoundStartIndexStaticModule_basic", + "SliceSizeTwoStepDivisibleStaticModule_basic", + "SliceSizeTwoStepModule_basic", + "SliceStartEqEndModule_basic", + "SliceStaticModule_basic", + "SliceWholeTensorModule_basic", + "SortIntListReverse_basic", + "SortIntList_basic", + "SplitTensorGetItem_Module_basic", + "SplitTensorLastSmallerModule_basic", + "SplitTensorListUnpackModule_basic", + "SplitTensorNegativeDimModule_basic", + "SplitWithSizesListUnpackModule_basic", + "SqrtIntConstantModule_basic", + "SqrtIntModule_basic", + "SqueezeDimModule_identity", + "SqueezeDimModule_static", + "SqueezeDimModule_unitDim", + "SqueezeModule_allUnitDim", + "SqueezeModule_static", + "SubFloatModule_basic", + "SubIntModule_basic", + "TModuleRank0_basic", + "TModuleRank1_basic", "TModuleRank2_basic", + "TensorIntModule_basic", "TensorLiteralModule_basic", - "TensorsConcatModule_basic", "TensorOpaqueLiteralModule_basic", - "TransposeIntModule_basic", - "TransposeIntNegDimsModule_basic", - "ToDtypeBoolLayoutNoneModule_basic", + "TensorToBoolZeroRank_basic", + "TensorToFloatZeroRank_basic", + "TensorToIntZeroRank_basic", + "TensorsConcatModule_basic", + "TensorsConcatNegativeDimModule_basic", + "TensorsConcatNegativeDimStaticModule_basic", + "TensorsConcatPromoteDTypeModule_basic", + "TensorsConcatPromoteDTypeStaticModule_basic", + "TensorsConcatStaticModule_basic", + "TensorsSplitTensorLastSmallerModule_basic", + "TensorsSplitTensorModule_basic", + "TensorsSplitTensorNegativeDimModule_basic", + "TestF16Return_basic", + "TestMultipleTensorAndPrimitiveTypesReturn_basic", + "TestMultipleTensorReturn_basic", + "TileBigDimsSizeModule_basic", + "TileSmallDimsSizeModule_basic", + "ToCopyBoolDTypeStaticModule_basic", "ToDtypeBoolLayoutNoneStaticModule_basic", + "ToDtypeLayoutCPUModule_basic", "ToDtypeLayoutNoneModule_basic", "ToDtypeLayoutStridedModule_basic", - "TypeAsSameModule_basic", + "TransposeIntModule_basic", + "TransposeIntNegDimsModule_basic", + "TriuBroadcastModule_basic", + "TriuModule_basic", + "TupleModule_basic", "TypeAsDifferentModule_basic", + "TypeAsSameModule_basic", "TypeConversionF32ToF64Module_basic", "TypeConversionF64ToF32Module_basic", "TypeConversionI1ToF32Module_basic", @@ -950,63 +851,69 @@ "TypeConversionI1ToI64Module_basic", "TypeConversionI32ToI64Module_basic", "TypeConversionI64ToI32Module_basic", - "TypePromotionAlphaWiderModule_basic", - "TypePromotionSameCategoryZeroRankWider_basic", - "TypePromotionZeroRankHigherCategoryModule_basic", - "OnesModuleCPUDevice_basic", - "Permute0RankModule_basic", + "UnbindIntGetItem_Module_basic", + "UnbindIntListUnpack_Module_basic", + "UnflattenIntNegativeOneDimStaticModule_basic", + "UnflattenIntNegativeOneSizeStaticModule_basic", + "UnflattenIntStaticModule_basic", + "UnflattenStaticModule_basic", + "UniformNoCorrelationModule_basic", + "UniformStaticShapeModule_basic", + "UnsafeView1DFoldModule_basic", "UnsafeViewCollapseModule_basic", "UnsafeViewDynamicExpandModule_basic", - "AtenRoundIntModule_basic", - "TestF16Return_basic", - "_LogSoftmaxModuleStable_basic", - "PrimsSqueezeModule_basic", - "PrimsSqueezeEmptyDimensionsModule_basic", - "MoveDimIntModule_basic", - "MoveDimIntNegativeIndexModule_basic", - "ConvolutionBackwardModule2DStatic_basic", - "ConvolutionBackwardModule2DStrided_basic", - "PrimsViewOfModule_basic", - "PrimsViewOfZeroRankModule_basic", - "AtenComplex64Module_basic", - "SplitTensorGetItem_Module_basic", - "SplitTensorListUnpackModule_basic", - "SplitTensorNegativeDimModule_basic", - "SplitTensorLastSmallerModule_basic", - "SplitWithSizesListUnpackModule_basic", - "UnbindIntListUnpack_Module_basic", - "UnbindIntGetItem_Module_basic", - "ChunkListUnpack_Module_basic", - "ChunkListUnpackUneven_Module_basic", - "RandIntDtypeModule_basic", - "RandIntLowDtypeModule_basic", - "RandIntLowModule_basic", - "RandIntModule_basic", - "RandIntPinMemoryModule_basic", - "RandModule_basic", - "UniformStaticShapeModule_basic", - "UniformNoCorrelationModule_basic", - "TupleModule_basic", - "AtenEmbeddingBagStaticModule_basic", + "UnsafeViewExpandModule_basic", + "View1DFoldModule_basic", + "ViewCollapseInferredDimModule_basic", + "ViewCollapseModule_basic", + "ViewCollapseOnesMiddleModule_basic", + "ViewDynamicExpandCollapseModule_basic", + "ViewDynamicExpandModule_basic", + "ViewExpandCollapseModule_basic", + "ViewExpandCollapseWithOnesModule_basic", + "ViewExpandDynamicDimModule_basic", + "ViewExpandInferredDimModule_basic", + "ViewExpandModule_basic", + "ViewExpandOnesBeforeAndAfterModule_basic", + "ViewExpandOnesMiddleModule_basic", + "ViewExpandOnesModule_basic", + "ViewNegativeStaticModule_basic", + "ViewNoChange1dModule_basic", + "ViewNoChange2dModule_basic", + "ViewNoChange3dModule_basic", + "ViewNoChangeStaticModule_basic", + "ViewOffsetBackwardTestStaticModule_basic", + "ViewOffsetTestStaticModule_basic", + "ViewTwoFiveThreeStaticModule_basic", + "ViewTwoToThreeStaticModule_basic", + "ZeroFloat32Module_basic", + "ZeroInt32Module_basic", + "ZeroInt64Module_basic", + "ZerosLikeModule_defaultDtype", + "ZerosLikeModule_falsePinMemory", + "ZerosLikeModule_float", + "ZerosLikeModule_int", + "ZerosModuleDefaultDtype_basic", + "ZerosModuleFalsePinMemory_basic", + "ZerosModuleFloat2D_basic", + "ZerosModuleFloat3D_basic", + "ZerosModuleInt2D_basic", + "ZerosModuleInt3D_basic", + "LinspaceDtypeModule_basic", + "LinspaceEmptyModule_basic", + "LinspaceModule_basic", + "LinspaceOneSizeModule_basic", + "LinspaceTwoSizeModule_basic", } -STABLEHLO_CRASHING_SET = { - # These e2e tests crash because currently mlir-hlo's shape-component-analysis - # only support exact one index in tensor::ExtractOp when it's related with - # some tensors' shape. REF: - # https://github.com/tensorflow/mlir-hlo/blob/master/mhlo/analysis/shape_component_analysis.cc#L586 - # FIXME if upstream mlir-hlo fix this. - "ViewCollapseDynamicWithAtenSizeIntModule_basic", - "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", - - "Aten_EmbeddingBagExample_basic", - "AtenEmbeddingBagSumExample_basic" +STABLEHLO_CRASHING_SET = { + "AtenEmbeddingBagSumExample_basic", } # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { - "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "AddCDiv_Module_basic", "AddCDivModule_basic", @@ -1031,6 +938,14 @@ "ArangeStartStepFloatModule_basic", "ArangeStartStepIntModule_basic", "ArangeZeroElementOutputModule_basic", + "ArangeDtypeIntModule_basic", + "ArangeFalsePinMemoryModule_basic", + "ArangeFloatModule_basic", + "ArangeNegativeStartFloatModule_basic", + "ArangeStartFloatModule_basic", + "ArangeStartNegativeStepFloatModule_basic", + "ArangeStartOutDtypeModule_basic", + "ArangeStartStepFloatModule_basic", "ArgmaxModule_keepDim", "ArgmaxModule_with_dim", "AtenComplex64Module_basic", @@ -1045,6 +960,7 @@ "AtenEyeModuleFloat2D_basic", "AtenEyeModuleInt2D_basic", "AtenRoundIntModule_basic", + "AtenInstanceNormModule_basic", "AtenToDeviceModule_basic", "AtenToDtypeModule_basic", "BaddbmmBroadcast1DInputModule_basic", @@ -1073,6 +989,8 @@ "BroadcastZeroRankInputStaticModule_basic", "BucketizeTensorStaticFloatModule_basic", "BucketizeTensorStaticModule_basic", + "CloneModule_basic", + "ChunkListUnpackUneven_Module_basic", "ChunkListUnpack_Module_basic", "ChunkListUnpackUneven_Module_basic", "ConstantBoolParameterModule_basic", @@ -1094,6 +1012,8 @@ "Conv2dWithPaddingModule_basic", "Convolution2DGroupsStatic_basic", "Convolution2DStaticModule_basic", + "CosineSimilarityStaticModule_basic", + "CosineSimilarityStaticBroadcastModule_basic", "DetachModule_basic", "DropoutEvalFloatModule_basic", "DropoutEvalIntModule_basic", @@ -1101,7 +1021,8 @@ "EinsumStaticContractRhsModule_basic", "EinsumStaticFourDimensionModule_basic", "EinsumStaticModule_basic", - "ElementwiseAbsModule_basic", + "ElementwiseAbsFloatModule_basic", + "ElementwiseAbsIntModule_basic", "ElementwiseAddModule_basic", "ElementwiseAddScalarFloatModule_basic", "ElementwiseAddScalarInt64Module_basic", @@ -1124,6 +1045,18 @@ "ElementwiseAtenLogicalXorOpModule_basic", "ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic", "ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic", + "ElementwiseAtenIsinfOpModule_basic", + "ElementwiseAtenIsneginfOpModule_basic", + "ElementwiseAtenIsposinfOpModule_basic", + "ElementwiseAtenLogicalOrOpBrodcastModule_basic", + "ElementwiseAtenLogicalOrOpDiffArgs1Module_basic", + "ElementwiseAtenLogicalOrOpDiffArgs2Module_basic", + "ElementwiseAtenLogicalOrOpDiffArgs3Module_basic", + "ElementwiseAtenLogicalOrOpModule_basic", + "ElementwiseAtenLogicalOrOpNegativeModule_basic", + "ElementwiseAtenLogicalOrOpPromoteBroadcastStaticShapeModule_basic", + "ElementwiseAtenLogicalOrOpRandomFloatModule_basic", + "ElementwiseAtenLogicalOrOpRandomModule_basic", "ElementwiseAtenWhereSelfModule_basic", "ElementwiseBinaryModule_basic", "ElementwiseBinaryStaticShapeModule_basic", @@ -1140,10 +1073,13 @@ "ElementwiseClampMaxModule_basic", "ElementwiseClampMinModule_basic", "ElementwiseClampModule_basic", + "ElementwiseClampTensorInt8Module_basic", "ElementwiseCloneChannelsLastMemoryFormatModule_basic", "ElementwiseCloneContiguousModule_basic", "ElementwiseCloneModule_basic", "ElementwiseDivScalarModule_basic", + "ElementwiseDivTensorIntegerModule_basic", + "ElementwiseDivTensorUnsignedIntegerModule_basic", "ElementwiseEluModule_basic", "ElementwiseEluNonDefaultModule_basic", "ElementwiseEqBoolScalarModule_basic", @@ -1170,6 +1106,8 @@ "ElementwiseGtIntTensorModule_basic", "ElementwiseGtMixed2ScalarModule_basic", "ElementwiseIsinfModule_basic", + "ElementwiseAtenIsneginfOpModule_basic", + "ElementwiseAtenIsposinfOpModule_basic", "ElementwiseIsnanModule_basic", "ElementwiseLeakyReluModule_basic", "ElementwiseLeakyReluStaticModule_basic", @@ -1180,6 +1118,8 @@ "ElementwiseLeIntScalarModule_basic", "ElementwiseLeIntTensorModule_basic", "ElementwiseLeMixedIntScalarModule_basic", + "ElementwiseLerpScalarIntModule_basic", + "ElementwiseLerpScalarFloatModule_basic", "ElementwiseLog2Module_basic", "ElementwiseLogModule_basic", "ElementwiseLtDiffWidthScalarModule_basic", @@ -1220,6 +1160,7 @@ "ElementwiseRemainderScalarModule_Int_basic", "ElementwiseRemainderScalarModule_Int_Float_basic", "ElementwiseRsqrtModule_basic", + "ElementwiseSeluModule_basic", "ElementwiseSigmoidModule_basic", "ElementwiseSignModule_basic", "ElementwiseSqrtIntModule_basic", @@ -1231,6 +1172,7 @@ "ElementwiseUnaryModule_basic", "ElementwiseUnsqueezeBroadcastModule_basic", "ElementwiseWhereScalarModule_basic", + "ElementwiseNanToNumModule_Basic", "EmbeddingModule1DIndices_basic", "EmbeddingModuleI32Static_basic", "EmptyModule_contiguous", @@ -1281,6 +1223,9 @@ "LiftFreshCopyModule_basic", "_LogSoftmaxModule_basic", "_LogSoftmaxModuleStable_basic", + "LinalgVectorNormKeepDimModule_basic", + "LinalgVectorNormModule_basic", + "LinalgNormKeepDimModule_basic", "MaskedFillScalarDefaultModule_basic", "MaskedFillScalarFloatValueModule_basic", "MaskedFillScalarFloatValueStaticModule_basic", @@ -1289,6 +1234,7 @@ "MaskedFillTensorIntValueStaticModule_basic", "Matmul_3d", "Matmul4dStatic_basic", + "Matmul4dStaticBroadcast_basic", "Matmul_dot", "MatmulStaticBroadcast_basic", "MaxPool2dEmptyStrideStaticModule_basic", @@ -1330,6 +1276,11 @@ "NewZerosModuleInt2D_basic", "NewZerosModuleInt3D_basic", "NewZerosStaticModuleLayoutStrided_basic", + "NormalizeModule_basic", + "NormScalarOptDimKeepDimModule_basic", + "NormScalarOptDimModule_basic", + "NumToTensorFloatModule_basic", + "NumToTensorIntModule_basic", "NumpyTRank0Module_basic", "NumpyTRank1Module_basic", "NumpyTRank2Module_basic", @@ -1388,6 +1339,7 @@ "SelectIntNegativeDimAndIndexStaticModule_basic", "SiluModule_basic", "SliceOutOfLowerBoundStartIndexStaticModule_basic", + "SliceOutOfUpperBoundIndexStaticModule_basic", "SliceSizeTwoStepDivisibleStaticModule_basic", "SliceStaticModule_basic", "SoftmaxIntModule_basic", @@ -1473,6 +1425,12 @@ "ZerosModuleFloat3D_basic", "ZerosModuleInt2D_basic", "ZerosModuleInt3D_basic", + "_LogSoftmaxModuleStable_basic", + "_LogSoftmaxModule_basic", + "_SoftmaxModule_basic", + "LinspaceModule_basic", + "LinspaceOneSizeModule_basic", + "LinspaceTwoSizeModule_basic", } MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | { @@ -1486,7 +1444,11 @@ "EyeStaticModule_basic", "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool1dStaticEvenMultiple_basic", + "CosineSimilarityModule_basic", "NativeGroupNormBackwardModule_basic", + "ReduceFrobeniusNormKeepDimModule_basic", + "ReduceFrobeniusNormModule_basic", "SliceWholeTensorModule_basic", "TensorFloatModule_basic", "TensorIntModule_basic", @@ -1500,6 +1462,7 @@ "NormalizeModule_basic", "ReduceFrobeniusNormKeepDimModule_basic", "ReduceFrobeniusNormModule_basic", + "SliceEndSleStartStaticModule_basic", }) - { ### Test failing in make_fx_tosa but not in tosa @@ -1513,6 +1476,8 @@ "Conv2dNoPaddingModule_basic", "Conv2dWithPaddingDilationStrideModule_basic", "Conv2dWithPaddingModule_basic", + + "AtenInstanceNormModule_basic", } MAKE_FX_TOSA_CRASHING_SET = {"CumsumModule_basic"} @@ -1536,6 +1501,7 @@ "PixelShuffleModuleFullDynamic_basic", "PixelShuffleModuleSpatiallyDynamic_basic", "PixelShuffleModuleSpatiallyStatic_basic", + "ConvTbcModule_basic", "_Convolution2DAllFalseModule_basic", "_Convolution2DBenchmarkModule_basic", "_Convolution2DCudnnModule_basic", @@ -1545,6 +1511,7 @@ "_ConvolutionDeprecated2DBenchmarkModule_basic", "_ConvolutionDeprecated2DCudnnModule_basic", "_ConvolutionDeprecated2DDeterministicModule_basic", + "MaxPool3dEmptyStrideStaticModule_basic", "AddIntModule_basic", "ArangeStartOutViewModule_basic", "AtenIntBoolOpModule_basic", @@ -1559,6 +1526,7 @@ "CeilFloatModule_basic", "DivFloatModule_basic", "EqIntModule_basic", + "ExponentialModule_basic", "GeFloatIntModule_basic", "GeFloatModule_basic", "GeIntModule_basic", @@ -1591,6 +1559,7 @@ "SliceStartEqEndModule_basic", "SqrtIntModule_basic", "SubFloatModule_basic", + "MulFloatModule_basic", "SubIntModule_basic", "TensorsStackPromoteDTypeModule_basic", "TensorToBoolZeroRank_basic", @@ -1604,8 +1573,11 @@ "ViewCollapseDynamicWithAtenSizeIntModule_basic", "AtenEmbeddingBagSumExample_basic", "Aten_EmbeddingBagExample_basic", + "ElementwiseLogitModule_basic", "ElementwiseRemainderScalarModule_Int_Float_basic", "ElementwiseRemainderScalarModule_Bool_basic", + "ElementwiseLerpScalarIntModule_basic", + "ElementwiseLerpScalarFloatModule_basic", "AtenIntTensorByteDtypeModule_basic", "AtenIntTensorCharDtypeModule_basic", "UpSampleNearest2dBackwardVec_basic", @@ -1619,6 +1591,7 @@ "VarMeanUnbiasedModule_basic", "RandnLikeModule_basic", "RandnLikeDtypeModule_basic", + "NormalFunctionalModule_basic", "BernoulliFloatModule_basic", "BernoulliModule_basic", "BernoulliPModule_basic", @@ -1655,5 +1628,684 @@ "ElementwiseBitwiseAndScalarInt64Module_basic", "ElementwiseBitwiseAndScalarInt32Module_basic", "ElementwiseBitwiseAndScalarInt8Module_basic", - "ElementwiseIsinfModule_basic", + "Conv2dQInt8Module_basic", } + +ONNX_XFAIL_SET = { + # Failure - cast error + "PermuteNegativeIndexModule_basic", + + # Failure - incorrect numerics + "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic", + "ElementwiseAtan2TensorIntModule_basic", + "ElementwiseLog10IntModule_basic", + "ElementwiseLog2IntModule_basic", + "ElementwiseSeluModule_basic", + "FlipModuleStaticShape_basic", + "FlipNegativeIndexModule_basic", + "HardsigmoidModule_basic", + "HardsigmoidRandomModule_basic", + "PixelShuffleModuleStaticRank4Float32_basic", + "ResNet18Module_basic", + "SliceCopyEndGreaterThanDimSize_Module_basic", + "SliceCopyNegative_Module_basic", + "SliceCopyNonZeroDim_Module_basic", + "SliceCopy_Module_basic", + "TupleModule_basic", + + # Failure - incorrect shape + "ArangeStartOutDtypeModule_basic", + "ArangeStartOutViewModule_basic", + "BroadcastDynamicDimModule_basic", + "MoveDimIntNegativeIndexModule_basic", + "ViewSizeFromOtherTensor_basic", + + # Failure - onnx_export + "AdaptiveAvgPool1dGeneralDynamic_basic", + "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool1dStaticLargerOutput_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", + "AdaptiveMaxPool2dDynamicWithIndices_basic", + "AdaptiveMaxPool2dDynamic_basic", + "AdaptiveMaxPool2dStaticWithIndices_basic", + "AdaptiveMaxPool2dStatic_basic", + "AddCDivModule_basic", + "AddIntModule_basic", + "Add_Module_basic", + "AllBoolFalseModule_basic", + "AllBoolTrueModule_basic", + "AnyBoolFalseModule_basic", + "AnyBoolTrueModule_basic", + "AtenComplex64Module_basic", + "AtenComplexImagModule_basic", + "AtenComplexRealModule_basic", + "AtenComplexViewModule_basic", + "AtenEmbeddingBagStaticModule_basic", + "AtenEmbeddingBagSumExample_basic", + "AtenFloatScalarModule_basic", + "AtenIntBoolOpConstFalseModule_basic", + "AtenIntBoolOpConstTrueModule_basic", + "AtenIntBoolOpModule_basic", + "AtenIntTensorByteDtypeModule_basic", + "AtenIntTensorCharDtypeModule_basic", + "AtenItemFpOpModule_basic", + "AtenItemIntOpModule_basic", + "AtenMmQuint8_basic", + "AtenRealView128Module_basic", + "AtenRealView64Module_basic", + "AtenSubFloatModule_basic", + "AtenTopKModule_basic", + "AtenTopKSmallestModule_basic", + "Aten_EmbeddingBagExample_basic", + "AvgPool2dWithoutPadModule_basic", + "BatchMlpLayerModule_basic", + "BincountMinlengthModule_basic", + "BincountModule_basic", + "BincountStaticSizeModule_basic", + "BoolFloatConstantModule_basic", + "BoolFloatFalseModule_basic", + "BoolFloatTrueModule_basic", + "BoolIntConstantModule_basic", + "BoolIntFalseModule_basic", + "BoolIntTrueModule_basic", + "CeilFloatModule_basic", + "ChunkListUnpackDynamic_Module_basic", + "ChunkListUnpackUnevenDynamic_Module_basic", + "CollapseAllDimensionsModule_basic", + "CollapseFullDynamicModule_basic", + "CollapsePartialDynamicModule_basic", + "CollapseRank1DynamicModule_basic", + "CollapseStaticModule_basic", + "ConstantBoolParameterModule_basic", + "ContainsIntList_False", + "ContainsIntList_True", + "Conv1dModule_basic", + "Conv2dBiasNoPaddingModule_basic", + "Conv2dModule_basic", + "Conv2dNoPaddingModule_basic", + "Conv2dQInt8Module_basic", + "Conv2dWithPaddingDilationStrideModule_basic", + "Conv2dWithPaddingModule_basic", + "Conv3dModule_basic", + "ConvTbcModule_basic", + "Conv_Transpose2dModule_basic", + "Convolution2DModule_basic", + "Convolution2DStridedModule_basic", + "ConvolutionBackwardModule2DPadded_basic", + "ConvolutionBackwardModule2DStatic_basic", + "ConvolutionBackwardModule2DStrided_basic", + "ConvolutionBackwardModule2D_basic", + "ConvolutionModule2DGroups_basic", + "ConvolutionModule2DTransposeNonUnitOutputPadding_basic", + "ConvolutionModule2DTransposeStrided_basic", + "ConvolutionModule2DTranspose_basic", + "DivFloatModule_basic", + "DivIntModule_basic", + "ElementwiseAcoshIntModule_basic", + "ElementwiseAcoshModule_basic", + "ElementwiseAsinhIntModule_basic", + "ElementwiseAsinhModule_basic", + "ElementwiseAtanhIntModule_basic", + "ElementwiseAtanhModule_basic", + "ElementwiseAtenIsneginfOpModule_basic", + "ElementwiseAtenIsposinfOpModule_basic", + "ElementwiseBitwiseAndModule_basic", + "ElementwiseBitwiseAndScalarInt32Module_basic", + "ElementwiseBitwiseAndScalarInt64Module_basic", + "ElementwiseBitwiseAndScalarInt8Module_basic", + "ElementwiseBitwiseAndStaticShapeModule_basic", + "ElementwiseBitwiseLeftShiftInt32Module_basic", + "ElementwiseBitwiseLeftShiftInt64Module_basic", + "ElementwiseBitwiseLeftShiftInt8Module_basic", + "ElementwiseBitwiseNotInt32Module_basic", + "ElementwiseBitwiseNotInt64Module_basic", + "ElementwiseBitwiseOrModule_basic", + "ElementwiseBitwiseOrStaticShapeModule_basic", + "ElementwiseBitwiseRightShiftInt32Module_basic", + "ElementwiseBitwiseRightShiftInt64Module_basic", + "ElementwiseBitwiseRightShiftInt8Module_basic", + "ElementwiseBitwiseXorModule_basic", + "ElementwiseBitwiseXorStaticShapeModule_basic", + "ElementwiseCoshIntModule_basic", + "ElementwiseCoshModule_basic", + "ElementwiseDequantizePerChannelModule_basic", + "ElementwiseDequantizePerTensorModule_basic", + "ElementwiseEluNonDefaultModule_basic", + "ElementwiseExpm1IntModule_basic", + "ElementwiseExpm1Module_basic", + "ElementwiseMulTensorComplexModule_basic", + "ElementwiseOrTensorModule_basic", + "ElementwiseOrTensorStaticShapeModule_basic", + "ElementwiseQuantizePerTensorModule_basic", + "ElementwiseRemainderTensorModule_Int_basic", + "ElementwiseFmodTensor_Float_basic", + "ElementwiseFmodTensor_Int_Float_basic", + "ElementwiseFmodTensor_Int_basic", + "EmptyStridedModule_basic", + "EmptyStridedSizeIntStrideModule_basic", + "EqIntModule_basic", + "ExponentialModule_basic", + "FloatImplicitModule_basic", + "GeFloatIntModule_basic", + "GeFloatModule_basic", + "GeIntModule_basic", + "GeluBackwardModule_basic", + "GtFloatIntModule_basic", + "GtIntModule_basic", + "HardtanhBackward_basic", + "IndexPutImpl1DFloatAccumulateModule_basic", + "IndexPutImpl1DFloatNonAccumulateModule_basic", + "IndexPutImpl1DIntAccumulateModule_basic", + "IndexPutImpl1DIntNonAccumulateModule_basic", + "IndexPutImpl2DFloatAccumulateModule_basic", + "IndexPutImpl2DFloatNonAccumulateModule_basic", + "IndexPutImpl2DIndexModule_basic", + "IndexPutImpl2DNoneIndexStaticModule_basic", + "IndexPutImpl3DFloatAccumulateModule_basic", + "IndexPutImpl3DFloatNonAccumulateModule_basic", + "IndexPutImplIndexWithNoneModule_basic", + "IntFloatModule_basic", + "IntImplicitModule_basic", + "IouOfModule_basic", + "IsFloatingPointFloat_True", + "IsFloatingPointInt_False", + "IscloseStaticModuleTrue_basic", + "IscloseStaticModule_basic", + "LeakyReluBackwardModule_basic", + "LeakyReluBackwardStaticModule_basic", + "LenStrModule_basic", + "LiftFreshCopyModule_basic", + "LogSoftmaxBackwardModule_basic", + "MaxPool2dCeilModeTrueModule_basic", + "MaxPool2dModule_basic", + "MaxPool2dWithIndicesAllOnesModule_basic", + "MaxPool2dWithIndicesBackwardDynamic3DModule_basic", + "MaxPool2dWithIndicesBackwardDynamic4DModule_basic", + "MaxPool2dWithIndicesBackwardStatic3DModule_basic", + "MaxPool2dWithIndicesBackwardStatic4DModule_basic", + "MaxPool2dWithIndicesCeilModeTrueModule_basic", + "MaxPool2dWithIndicesFullSizeKernelModule_basic", + "MaxPool2dWithIndicesModule_basic", + "MaxPool2dWithIndicesNonDefaultDilationModule_basic", + "MaxPool2dWithIndicesNonDefaultParamsModule_basic", + "MaxPool2dWithIndicesNonDefaultStrideModule_basic", + "MaxPool3dCeilModeTrueModule_basic", + "MaxPool3dLargeDatadModule_basic", + "MaxPool3dModuleRandomSimple_basic", + "MaxPool3dModule_basic", + "MeanDimEmptyDimModule_basic", + "Mlp1LayerModule_basic", + "Mlp2LayerModuleNoBias_basic", + "Mlp2LayerModule_basic", + "MulFloatModule_basic", + "MulIntModule_basic", + "NarrowHorizontalTest2_basic", + "NarrowHorizontalTest_basic", + "NarrowTensorHorizontalModule_basic", + "NarrowTensorVerticalModule_basic", + "NarrowVerticalTest2_basic", + "NarrowVerticalTest_basic", + "NativeBatchNorm1DModule_basic", + "NativeBatchNorm2DModule_basic", + "NativeBatchNorm3DModule_basic", + "NativeBatchNormNoneWeightModule_basic", + "NativeDropoutEvalFloatModule_basic", + "NativeGroupNormBackwardModule_basic", + "NativeGroupNormModule_basic", + "NativeLayerNormDynamicModule_basic", + "NeFloatIntModule_basic", + "NeIntModule_basic", + "NewEmptyStridedModuleDefaultDtype_basic", + "NllLossModuleBackward1DMeanWeight_basic", + "NllLossModuleBackward1DMean_basic", + "NllLossModuleBackward1DSumWeight_basic", + "NllLossModuleBackward1DSum_basic", + "NllLossModuleBackward1DWeight_basic", + "NllLossModuleBackward1D_basic", + "NllLossModuleBackwardMeanWeight_basic", + "NllLossModuleBackwardMean_basic", + "NllLossModuleBackwardSumWeight_basic", + "NllLossModuleBackwardSum_basic", + "NllLossModuleBackwardWeight_basic", + "NllLossModuleBackward_basic", + "NllLossModuleBackward_ignore_index", + "NllLossModule_1D_basic", + "NllLossModule_basic", + "NllLossModule_ignore_index_out_of_bounds_basic", + "NllLossModule_mean_basic", + "NllLossModule_sum_basic", + "NormScalarModule_basic", + "NormScalarOptDimKeepDimModule_basic", + "NormScalarOptDimModule_basic", + "NormalFunctionalModule_basic", + "NumToTensorFloatModule_basic", + "NumToTensorIntModule_basic", + "NumelModule_basic", + "NumelZeroRankModule_basic", + "PixelShuffleModuleFullDynamic_basic", + "PixelShuffleModuleSpatiallyDynamic_basic", + "PixelShuffleModuleSpatiallyStatic_basic", + "PixelShuffleModuleStaticRank3Int64_basic", + "PowIntFloatModule_basic", + "PrimMaxIntModule_basic", + "PrimMinIntDynamicModule_basic", + "PrimMinIntModule_basic", + "PrimsConvertElementTypeModule_basic", + "PrimsSqueezeEmptyDimensionsModule_basic", + "PrimsSqueezeModule_basic", + "PrimsViewOfModule_basic", + "PrimsViewOfZeroRankModule_basic", + "RandIntDtypeModule_basic", + "RandIntModule_basic", + "RandIntPinMemoryModule_basic", + "ReshapeAliasCollapseModule_basic", + "ReshapeAliasExpandModule_basic", + "ReshapeExpandModule_basic", + "ScalarConstantTupleModule_basic", + "ScalarImplicitFloatModule_basic", + "ScalarImplicitIntModule_basic", + "ScatterReduceFloatMaxModule", + "ScatterReduceFloatMeanModule", + "ScatterReduceFloatMeanModuleIncludeSelf", + "ScatterReduceFloatMinModule", + "ScatterReduceFloatProdModule", + "ScatterReduceFloatSumModule", + "ScatterReduceIntMaxModule", + "ScatterReduceIntMeanModule", + "ScatterReduceIntMeanModuleIncludeSelf", + "ScatterReduceIntMinModule", + "ScatterReduceIntProdModule", + "ScatterReduceIntSumModule", + "SelectScattertModule_basic", + "SelectScattertStaticModule_basic", + "SliceEndSleStartModule_basic", + "SliceOutOfUpperBoundIndexModule_basic", + "SliceScatterModule_basic", + "SliceScatterNegativeDimModule_basic", + "SliceScatterNegativeEndModule_basic", + "SliceScatterStaticModule_basic", + "SliceScatterStepVariationModule_basic", + "SliceScatterZeroDimModule_basic", + "SliceStartEqEndModule_basic", + "SoftmaxBackwardModule_basic", + "SortIntListReverse_basic", + "SortIntList_basic", + "SplitDimDynamicModule_basic", + "SplitDimStaticModule_basic", + "SqrtIntConstantModule_basic", + "SqrtIntModule_basic", + "StdCorrectionEmptyDimModule_basic", + "StdDimEmptyDimModule_basic", + "SubFloatModule_basic", + "SubIntModule_basic", + "TanhBackward_basic", + "TensorToBoolZeroRank_basic", + "TensorToBool_basic", + "TensorToFloatZeroRank_basic", + "TensorToFloat_basic", + "TensorToIntZeroRank_basic", + "TensorToInt_basic", + "TestMultipleTensorAndPrimitiveTypesReturn_basic", + "Threshold1dFloatModule_basic", + "Threshold1dIntI32Module_basic", + "Threshold1dIntModule_basic", + "Threshold2dFloatModule_basic", + "Threshold2dIntModule_basic", + "Threshold3dFloatModule_basic", + "Threshold3dIntModule_basic", + "ThresholdBackward1dFloatModule_basic", + "ThresholdBackward1dIntModule_basic", + "ThresholdBackward1dMixedModule_basic", + "ThresholdBackward2dFloatModule_basic", + "ThresholdBackward2dIntModule_basic", + "ThresholdBackward2dMixedModule_basic", + "ThresholdBackward3dFloatModule_basic", + "ThresholdBackward3dIntModule_basic", + "ThresholdBackward3dMixedModule_basic", + "ToCopyBoolDTypeStaticModule_basic", + "ToCopyModule_basic", + "ToCopyWithDTypeFalsePinMemoryModule_basic", + "ToCopyWithDTypeModule_basic", + "TorchPrimLoopForLikeModule_basic", + "TorchPrimLoopWhileLikeModule_basic", + "TraceModule_basic", + "TraceModule_empty", + "TraceModule_nonsquare", + "TraceSignedIntModule_basic", + "TraceUnsignedIntModule_basic", + "TraceUnsignedIntModule_empty", + "UniformModule_basic", + "UniformNoCorrelationModule_basic", + "UniformStaticShapeModule_basic", + "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "UnsafeView1DFoldModule_basic", + "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", + "UnsafeViewCollapseModule_basic", + "UnsafeViewDynamicExpandModule_basic", + "UnsafeViewDynamicExpandWithAtenSizeIntModule_basic", + "UnsafeViewExpandModule_basic", + "UpSampleNearest2dBackwardScalesNone_basic", + "UpSampleNearest2dBackward_basic", + "UpSampleNearest2dDynamicFactor_basic", + "UpSampleNearest2dStaticFactor_basic", + "UpSampleNearest2d_basic", + "VarCorrectionEmptyDimModule_basic", + "VarDimEmptyDimModule_basic", + "ViewCollapseDynamicWithAtenSizeIntModule_basic", + "ViewCollapseModule_basic", + "ViewDynamicExpandCollapseModule_basic", + "ViewDynamicExpandCollapseWithAtenIntModule_basic", + "ViewDynamicExpandModule_basic", + "ViewDynamicExpandWithAtenSizeIntModule_basic", + "ViewExpandDynamicDimModule_basic", + "ViewNoChange1dModule_basic", + "ViewNoChange2dModule_basic", + "ViewNoChange3dModule_basic", + "_Convolution2DAllFalseModule_basic", + "_Convolution2DBenchmarkModule_basic", + "_Convolution2DCudnnModule_basic", + "_Convolution2DDeterministicModule_basic", + "_Convolution2DTF32Module_basic", + "_ConvolutionDeprecated2DAllFalseModule_basic", + "_ConvolutionDeprecated2DBenchmarkModule_basic", + "_ConvolutionDeprecated2DCudnnModule_basic", + "_ConvolutionDeprecated2DDeterministicModule_basic", + "_SoftmaxModule_basic", + + # Failure - onnx_import + "DiagonalModule_basic", + "DiagonalModule_nonsquare", + "DiagonalModule_transposed", + "DiagonalModule_with_dims", + "DiagonalModule_with_dims_and_offset", + "DiagonalModule_with_negative_dims", + "DiagonalModule_with_offset", + "ScatterReduceFloatMaxModuleIncludeSelf", + "ScatterReduceFloatMinModuleIncludeSelf", + "ScatterReduceFloatProdModuleIncludeSelf", + "ScatterReduceFloatSumModuleIncludeSelf", + "ScatterReduceIntMaxModuleIncludeSelf", + "ScatterReduceIntMinModuleIncludeSelf", + "ScatterReduceIntProdModuleIncludeSelf", + "ScatterReduceIntSumModuleIncludeSelf", + "TileBigDimsSizeModule_basic", + "TileSmallDimsSizeModule_basic", + "LinalgNormKeepDimModule_basic", + "LinalgNormModule_basic", + + # Failure - onnx_lowering: onnx.AveragePool + "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool1dStaticEvenMultiple_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", + "AvgPool1dFloatModule_basic", + "AvgPool1dIntModule_basic", + "AvgPool1dStaticModule_basic", + "AvgPool2dCeilModeTrueModule_basic", + "AvgPool2dDivisorOverrideModule_basic", + "AvgPool2dFloatModule_basic", + "AvgPool2dIntModule_basic", + "AvgPool2dStaticModule_basic", + + # Failure - onnx_lowering: onnx.Cast + "BucketizeTensorOutInt32RightModule_basic", + "ElementwiseToDtypeI64ToI8Module_basic", + "ElementwiseToDtypeI64ToUI8Module_basic", + "HBC_basic", + "QuantizedMLP_basic", + "TypeConversionI1ToI32Module_basic", + "TypeConversionI64ToI32Module_basic", + + # Failure - onnx_lowering: onnx.Clip + "NormalizeModule_basic", + + # Failure - onnx_lowering: onnx.Einsum + "EinsumStaticContractRhsModule_basic", + "EinsumStaticFourDimensionModule_basic", + "EinsumStaticModule_basic", + + # Failure - onnx_lowering: onnx.HardSwish + "HardswishModule_basic", + "HardswishRandomModule_basic", + "MobilenetV3Module_basic", + + # Failure - onnx_lowering: onnx.LogSoftmax + "LogSoftmaxIntModule_basic", + "_LogSoftmaxModuleStable_basic", + "_LogSoftmaxModule_basic", + + # Failure - onnx_lowering: onnx.MaxPool + "MaxPool2dWithIndicesAllNegativeValuesModule_basic", + "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", + "MaxPool2dWithIndicesStaticModule_basic", + + # Failure - onnx_lowering: onnx.Mod + "ElementwiseRemainderScalarModule_Bool_basic", + "ElementwiseRemainderScalarModule_Int_basic", + "UnflattenIntNegativeOneDimStaticModule_basic", + "UnflattenIntNegativeOneSizeStaticModule_basic", + "UnflattenIntStaticModule_basic", + "UnflattenStaticModule_basic", + + # Failure - onnx_lowering: onnx.OneHot + "OneHotModule_basic", + + # Failure - onnx_lowering: onnx.Pad + "ReflectionPad1dModule2dInput_Right", + "ReflectionPad1dModule2dInput_basic", + "ReflectionPad1dModule3dInput_Left", + "ReflectionPad1dModule3dInput_basic", + "ReflectionPad2dModule_Bottom", + "ReflectionPad2dModule_Left", + "ReflectionPad2dModule_Right", + "ReflectionPad2dModule_Top", + "ReflectionPad2dModule_basic", + "ReplicationPad2dModule_basic", + "ReplicationPad2dModule_bottom0", + "ReplicationPad2dModule_left0", + "ReplicationPad2dModule_right0", + "ReplicationPad2dModule_top0", + + # Failure - onnx_lowering: onnx.RandomNormal + "RandnDtypeDeviceModule_basic", + "RandnGeneratorF64Module_basic", + "RandnGeneratorModule_basic", + "RandnModule_basic", + + # Failure - onnx_lowering: onnx.RandomNormalLike + "RandnLikeDtypeModule_basic", + "RandnLikeModule_basic", + + # Failure - onnx_lowering: onnx.RandomUniform + "RandIntLowDtypeModule_basic", + "RandIntLowModule_basic", + + # Failure - onnx_lowering: onnx.RandomUniformLike + "BernoulliFloatModule_basic", + "BernoulliPModule_basic", + "BernoulliTensorModule_basic", + "RandLikeDtypeModule_basic", + "RandLikeModule_basic", + "RandModule_basic", + + # Failure - onnx_lowering: onnx.ReduceL1 + "ReduceL1NormModule_basic", + "ReduceL1NormWithDTypeModule_basic", + + # Failure - onnx_lowering: onnx.ReduceL2 + "ReduceL2NormModule_basic", + + # Failure - onnx_lowering: onnx.ReduceProd + "BernoulliModule_basic", + "DropoutTrainModule_basic", + "DropoutTrainStaticShapeModule_basic", + "NativeDropoutTrainModule_basic", + "NativeDropoutTrainStaticShapeModule_basic", + "ReduceProdDimIntFloatModule_basic", + "StdCorrectionLargeInputModule_basic", + "VarCorrectionLargeInputModule_basic", + + # Failure - onnx_lowering: onnx.ReduceSum + "MseLossSumReductionWithDifferentElemTypeModule_basic", + "ReduceL3NormAllDimsModule_basic", + "ReduceL3NormKeepDimModule_basic", + "ReduceSumDtypeFloatModule_basic", + "ReduceSumDtypeIntModule_basic", + "ReduceSumElementTypeBoolModule_basic", + "ReduceSumFloatModule_basic", + "ReduceSumSignedIntModule_basic", + "ReduceSumUnsignedIntModule_basic", + + # Failure - onnx_lowering: onnx.Resize + "UpSampleNearest2dDynamicSize_basic", + "UpSampleNearest2dStaticSize_basic", + + # Failure - onnx_lowering: onnx.ScatterElements + "ScatterSrcModule_basic", + "ScatterSrcStaticModule_basic", + "ScatterValueFloatModule_basic", + "ScatterValueIntModule_basic", + + # Failure - onnx_lowering: onnx.ScatterND + "IndexPut1DFloatAccumulateModule_basic", + "IndexPut1DFloatNonAccumulateModule_basic", + "IndexPut1DIntAccumulateModule_basic", + "IndexPut1DIntNonAccumulateModule_basic", + "IndexPut2DFloatAccumulateModule_basic", + "IndexPut2DFloatNonAccumulateModule_basic", + "IndexPut2DIntAccumulateModule_basic", + "IndexPut2DIntNonAccumulateModule_basic", + "IndexPut3DFloatAccumulateModule_basic", + "IndexPut3DFloatNonAccumulateModule_basic", + "IndexPut3DIntAccumulateModule_basic", + "IndexPut3DIntNonAccumulateModule_basic", + "IndexPutHackedTwin1DFloatAccumulateModule_basic", + "IndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin1DIntAccumulateModule_basic", + "IndexPutHackedTwin1DIntNonAccumulateModule_basic", + "IndexPutHackedTwin2DFloatAccumulateModule_basic", + "IndexPutHackedTwin2DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin2DIntAccumulateModule_basic", + "IndexPutHackedTwin2DIntNonAccumulateModule_basic", + "IndexPutHackedTwin3DFloatAccumulateModule_basic", + "IndexPutHackedTwin3DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin3DIntAccumulateModule_basic", + "IndexPutHackedTwin3DIntNonAccumulateModule_basic", + + # Failure - onnx_lowering: onnx.SoftmaxCrossEntropyLoss + "CrossEntropyLossModule_basic", + "CrossEntropyLossNoReductionModule_basic", + + # Failure - onnx_lowering: onnx.Softplus + "ElementwiseMishModule_basic", + "SoftplusModule_basic", + + # Failure - onnx_lowering: onnx.Squeeze + "SqueezeModule_allUnitDim", + "SqueezeModule_broadcast", + "SqueezeModule_static", + + # Failure - onnx_lowering: onnx.TopK + "SortTensorDescending_basic", + "SortTensorInteger_basic", + "SortTensorNegativeDimension_basic", + "SortTensorSpecificDimension_basic", + "SortTensor_basic", + + # Failure - onnx_lowering: onnx.Trilu + "AtenTrilModule_basic", + "AtenTrilWithNegDiagonalModule_basic", + "AtenTrilWithPosDiagonalModule_basic", + "AtenTriuModule_basic", + "AtenTriuWithNegDiagonalModule_basic", + "AtenTriuWithPosDiagonalModule_basic", + "TriuBroadcastModule_basic", + "TriuModule_basic", + + # Failure - incorrect dtype + "ReduceMaxAlongDimUnsignedInt_basic", + + # Failure - torch.aten.view lower + "IndexTensorDyanmicInputContiguousWithNoneModule_basic", + "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic", + "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic", + "IndexTensorMultiInputContiguousCenter_basic", + "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic", + "IndexTensorMultiInputNonContiguous_basic", + "IndexTensorMultiInputOneDim_basic", + "IndexTensorMultiInputThreeIndexers_basic", + "IndexTensorMultiInput_basic", + "ViewFlattenAndExpandModule_basic", + "ViewSizeDimFollowedByCollapsedOnesModule_basic", + "ViewSizeDimFollowedByExpandedOnesModule_basic", + "ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic", + "ViewSizeDimLedAndFollowedByExpandedOnesModule_basic", + "ViewSizeDimLedByCollapsedOnesModule_basic", + "ViewSizeDimLedByExpandedOnesModule_basic", + + # Failure - unknown + "BucketizeTensorFloatModule_basic", + "BucketizeTensorModule_basic", + "BucketizeTensorStaticFloatModule_basic", + "BucketizeTensorStaticModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", + "CopyWithDifferentDTypesAndSizesModule_basic", + "CopyWithDifferentDTypesModule_basic", + "CosineSimilarityStaticBroadcastModule_basic", + "CumsumInputDtypeInt32Module_basic", + "ElementwiseAcosIntModule_basic", + "ElementwiseAsinIntModule_basic", + "ElementwiseAtanTensorIntModule_basic", + "ElementwiseCosIntModule_basic", + "ElementwiseDivRoundingModeTruncModule_basic", + "ElementwiseErfIntModule_basic", + "ElementwiseExpIntModule_basic", + "ElementwiseLogIntModule_basic", + "ElementwisePreluModule_basic", + "ElementwiseSigmoidIntModule_basic", + "ElementwiseSinIntModule_basic", + "ElementwiseTanIntModule_basic", + "ElementwiseUnaryIntModule_basic", + "ElementwiseUnsqueezeNegDimsModule_basic", + "EmbeddingModuleF16_basic", + "EmbeddingModuleI32_basic", + "EmbeddingModuleI64_basic", + "FlattenDynamicModule_basic", + "GluStaticModule_basic", + "GroupNormModule_basic", + "IndexTensorHackedTwinModule3dInput_basic", + "IndexTensorHackedTwinModule_basic", + "IndexTensorModule3dInput_basic", + "IndexTensorModule_basic", + "IndexTensorMultiInputContiguousOneDimDynamic_basic", + "IndexTensorMultiInputNonContiguousDynamic_basic", + "IndexTensorMultiInputNonContiguousOneDimDynamic_basic", + "IndexTensorSelectDimModule_basic", + "MaskedFillTensorFloatValueModule_basic", + "ReduceAllDimEmpty_basic", + "ReduceAllDimFloat_basic", + "ReduceAllDimInt_basic", + "ReduceMinAlongDimUnsignedInt_basic", + "TensorsStackNegativeDimModule_basic", + "TensorsStackPromoteDTypeModule_basic", + + # Failure - "RuntimeError: linalg.cross: inputs dimension 1 must have length 3. Got 1 and 1" + "AtenLinalgCrossDynamic_basic", + + # Only on feature/backport_ea1_ops + "AtenToDtypeModule_basic", + "Conv1dNoPaddingGroupModule_basic", + "ElementwiseAcosTensorIntModule_basic", + "ElementwiseAsinTensorIntModule_basic", + "FakeQuantizePerTensorAffineCachemaskModule_basic", + "Im2ColModule_basic", + "IndexPutImpl2DNoneIndexBroadcastStaticModule_basic", + "PrimsSumFloatModule_basic", + "RepeatInterleaveFillModule_basic", + "RepeatInterleaveModule_basic", + "RepeatInterleaveStaticModule_basic", + "SliceCopyMax_Module_basic", +} + +ONNX_CRASHING_SET = { } + diff --git a/projects/pt1/examples/torchdynamo_resnet18.py b/projects/pt1/examples/torchdynamo_resnet18.py index d7abd80da665..377c632da36f 100644 --- a/projects/pt1/examples/torchdynamo_resnet18.py +++ b/projects/pt1/examples/torchdynamo_resnet18.py @@ -14,7 +14,7 @@ import torchvision.models as models from torchvision import transforms -import torch_mlir +from torch_mlir import torchscript from torch_mlir.dynamo import make_simple_dynamo_backend from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend @@ -71,7 +71,7 @@ def predictions(torch_func, jit_func, img, labels): @make_simple_dynamo_backend def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): - mlir_module = torch_mlir.compile( + mlir_module = torchscript.compile( fx_graph, example_inputs, output_type="linalg-on-tensors") backend = refbackend.RefBackendLinalgOnTensorsBackend() compiled = backend.compile(mlir_module) diff --git a/projects/pt1/examples/torchscript_resnet18.py b/projects/pt1/examples/torchscript_resnet18.py index ac46e6f4523b..62e5eda6cc83 100644 --- a/projects/pt1/examples/torchscript_resnet18.py +++ b/projects/pt1/examples/torchscript_resnet18.py @@ -12,7 +12,7 @@ import torchvision.models as models from torchvision import transforms -import torch_mlir +from torch_mlir import torchscript from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend @@ -67,7 +67,7 @@ def predictions(torch_func, jit_func, img, labels): resnet18 = models.resnet18(pretrained=True) resnet18.train(False) -module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors") +module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors") backend = refbackend.RefBackendLinalgOnTensorsBackend() compiled = backend.compile(module) jit_module = backend.load(compiled) diff --git a/projects/pt1/examples/torchscript_resnet18_all_output_types.py b/projects/pt1/examples/torchscript_resnet18_all_output_types.py index a17fa40521d3..70a920550b2d 100644 --- a/projects/pt1/examples/torchscript_resnet18_all_output_types.py +++ b/projects/pt1/examples/torchscript_resnet18_all_output_types.py @@ -6,15 +6,15 @@ import torch import torchvision -import torch_mlir +from torch_mlir import torchscript resnet18 = torchvision.models.resnet18(pretrained=True) resnet18.eval() -module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch") +module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch") print("TORCH OutputType\n", module.operation.get_asm(large_elements_limit=10)) -module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors") +module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors") print("LINALG_ON_TENSORS OutputType\n", module.operation.get_asm(large_elements_limit=10)) # TODO: Debug why this is so slow. -module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa") +module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa") print("TOSA OutputType\n", module.operation.get_asm(large_elements_limit=10)) diff --git a/projects/pt1/examples/torchscript_resnet_inference.ipynb b/projects/pt1/examples/torchscript_resnet_inference.ipynb index 3ab7cc64dadb..9970f90b8bb2 100644 --- a/projects/pt1/examples/torchscript_resnet_inference.ipynb +++ b/projects/pt1/examples/torchscript_resnet_inference.ipynb @@ -184,7 +184,7 @@ "\n", "# Compile the model with an example input.\n", "# We lower to the linalg-on-tensors form that the reference backend supports.\n", - "compiled = torch_mlir.compile(TanhModule(), torch.ones(3), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)\n", + "compiled = torch_mlir.torchscript.compile(TanhModule(), torch.ones(3), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)\n", "# Load it on the reference backend.\n", "jit_module = compile_and_load_on_refbackend(compiled)\n", "# Run it!\n", @@ -326,7 +326,7 @@ "source": [ "resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)\n", "resnet18.eval()\n", - "compiled = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=\"linalg-on-tensors\")\n", + "compiled = torch_mlir.torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=\"linalg-on-tensors\")\n", "jit_module = compile_and_load_on_refbackend(compiled)" ] }, diff --git a/projects/pt1/examples/torchscript_stablehlo_backend_resnet.py b/projects/pt1/examples/torchscript_stablehlo_backend_resnet.py index 7a97359cff62..e42828ed776e 100644 --- a/projects/pt1/examples/torchscript_stablehlo_backend_resnet.py +++ b/projects/pt1/examples/torchscript_stablehlo_backend_resnet.py @@ -1,13 +1,13 @@ import torch import torchvision.models as models -import torch_mlir +from torch_mlir import torchscript model = models.resnet18(pretrained=True) model.eval() data = torch.randn(2,3,200,200) out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir" -module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=False) +module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=False) with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf: outf.write(str(module)) diff --git a/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py b/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py index c035be3a54fe..c68daf12dd86 100644 --- a/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py +++ b/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py @@ -1,5 +1,5 @@ import torch -import torch_mlir +from torch_mlir import torchscript from transformers import BertForMaskedLM @@ -17,7 +17,7 @@ def forward(self, data): data = torch.randint(30522, (2, 128)) out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir" -module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=True) +module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=True) with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf: outf.write(str(module)) diff --git a/projects/pt1/python/CMakeLists.txt b/projects/pt1/python/CMakeLists.txt index 6ed43a7317c8..eedbf83c6e46 100644 --- a/projects/pt1/python/CMakeLists.txt +++ b/projects/pt1/python/CMakeLists.txt @@ -7,81 +7,24 @@ set(CMAKE_PLATFORM_NO_VERSIONED_SONAME ON) # argument. set(TORCH_MLIR_PYTHON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/torch_mlir") - # We vendor our own MLIR instance in the `torch_mlir` namespace. add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=torch_mlir.") -################################################################################ -# PyTorch -################################################################################ - -if (NOT TORCH_MLIR_USE_INSTALLED_PYTORCH) - # Source builds - set(ENV{TORCH_MLIR_SRC_PYTORCH_REPO} ${TORCH_MLIR_SRC_PYTORCH_REPO}) - set(ENV{TORCH_MLIR_SRC_PYTORCH_BRANCH} ${TORCH_MLIR_SRC_PYTORCH_BRANCH}) - set(ENV{TM_PYTORCH_INSTALL_WITHOUT_REBUILD} ${TM_PYTORCH_INSTALL_WITHOUT_REBUILD}) - set(ENV{MACOSX_DEPLOYMENT_TARGET} ${MACOSX_DEPLOYMENT_TARGET}) - set(ENV{CMAKE_OSX_ARCHITECTURES} ${CMAKE_OSX_ARCHITECTURES}) - set(ENV{CMAKE_C_COMPILER_LAUNCHER} ${CMAKE_C_COMPILER_LAUNCHER}) - set(ENV{CMAKE_CXX_COMPILER_LAUNCHER} ${CMAKE_CXX_COMPILER_LAUNCHER}) - execute_process( - COMMAND ${TORCH_MLIR_SOURCE_DIR}/build_tools/build_libtorch.sh - RESULT_VARIABLE _result - ) - if(_result) - message(FATAL_ERROR "Failed to run `build_libtorch.sh`") - endif() - set(TORCH_INSTALL_PREFIX "libtorch") -endif() - -################################################################################ -# Sources -################################################################################ - -declare_mlir_python_sources(TorchMLIRPythonSources) -declare_mlir_python_sources(TorchMLIRPythonExtensions) - -if (NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS) - declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel - ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" - ADD_TO_PARENT TorchMLIRPythonSources - SOURCES - __init__.py - repro.py - fx_minifier.py - _dynamo_fx_importer.py - compiler_utils.py - dynamo.py - _version.py - ) -endif() - -declare_mlir_python_sources(TorchMLIRPythonSources.Dialects - ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" - ADD_TO_PARENT TorchMLIRPythonSources -) +# ################################################################################ +# # Sources +# ################################################################################ -declare_mlir_dialect_python_bindings( - ADD_TO_PARENT TorchMLIRPythonSources.Dialects +declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" - TD_FILE dialects/TorchBinding.td - SOURCES dialects/torch/__init__.py - DIALECT_NAME torch -) - -################################################################################ -# Extensions -################################################################################ - -declare_mlir_python_extension(TorchMLIRPythonExtensions.Main - MODULE_NAME _torchMlir - ADD_TO_PARENT TorchMLIRPythonExtensions + ADD_TO_PARENT TorchMLIRPythonTorchExtensionsSources SOURCES - TorchMLIRModule.cpp - EMBED_CAPI_LINK_LIBS - TorchMLIRCAPI - PRIVATE_LINK_LIBS - LLVMSupport + torchscript.py + _dynamo_fx_importer.py + compiler_utils.py + dynamo.py + repro.py + fx_minifier.py + _version.py ) ################################################################################ @@ -112,56 +55,23 @@ endif() # add_subdirectory(torch_mlir/_torch_mlir_custom_op_example) -################################################################################ -# Generate packages and shared library -# Downstreams typically will not use these, but they are useful for local -# testing. -################################################################################ - -set(_source_components - # TODO: Core is now implicitly building/registering all dialects, increasing - # build burden by ~5x. Make it stop. - # TODO: Reduce dependencies. We need ExecutionEngine and a bunch of passes - # for the reference backend, but logically they can be separate. But seemingly - # the only way to handle that is to create a separate mlir python package - # tree, which seems excessive. - MLIRPythonSources - MLIRPythonExtension.Core - MLIRPythonExtension.RegisterEverything - TorchMLIRPythonSources - TorchMLIRPythonExtensions -) - -add_mlir_python_common_capi_library(TorchMLIRAggregateCAPI - INSTALL_COMPONENT TorchMLIRPythonModules - INSTALL_DESTINATION python_packages/torch_mlir/torch_mlir/_mlir_libs - OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs" - RELATIVE_INSTALL_ROOT "../../../.." - DECLARED_SOURCES ${_source_components} -) - -add_mlir_python_modules(TorchMLIRPythonModules - ROOT_PREFIX "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir" - INSTALL_PREFIX "python_packages/torch_mlir/torch_mlir" - DECLARED_SOURCES ${_source_components} - COMMON_CAPI_LINK_LIBS - TorchMLIRAggregateCAPI - ) - # TODO: Find a cleaner way to do this. # Can we build the JIT IR importer with `declare_mlir_python_extension`? # Then it would "just work". if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER) - add_dependencies(TorchMLIRPythonModules TorchMLIRJITIRImporter) - add_dependencies(TorchMLIRPythonModules TorchMLIRJITIRImporterPybind) - # Build the E2E Tests (which depend on the JIT IR importer now). - add_dependencies(TorchMLIRPythonModules TorchMLIRE2ETestPythonModules) + add_dependencies(TorchMLIRPythonTorchExtensionsSources + TorchMLIRJITIRImporter + TorchMLIRJITIRImporterPybind + TorchMLIRE2ETestPythonModules + ) endif() if(TORCH_MLIR_ENABLE_LTC) # Add Torch-MLIR LTC backend as dependency - add_dependencies(TorchMLIRPythonModules torch_mlir_ltc_backend) - add_dependencies(TorchMLIRPythonModules reference_lazy_backend) + add_dependencies(TorchMLIRPythonTorchExtensionsSources + torch_mlir_ltc_backend + reference_lazy_backend + ) endif() add_subdirectory(test) diff --git a/projects/pt1/python/test/compile_api/already_scripted.py b/projects/pt1/python/test/compile_api/already_scripted.py index 367170081228..7d9720727a38 100644 --- a/projects/pt1/python/test/compile_api/already_scripted.py +++ b/projects/pt1/python/test/compile_api/already_scripted.py @@ -6,7 +6,7 @@ # RUN: %PYTHON %s | FileCheck %s import torch -import torch_mlir +from torch_mlir import torchscript class BasicModule(torch.nn.Module): @@ -15,17 +15,17 @@ def sin(self, x): return torch.ops.aten.sin(x) -example_args = torch_mlir.ExampleArgs() +example_args = torchscript.ExampleArgs() example_args.add_method("sin", torch.ones(2, 3)) scripted = torch.jit.script(BasicModule()) -print(torch_mlir.compile(scripted, example_args)) +print(torchscript.compile(scripted, example_args)) # CHECK: module # CHECK-DAG: func.func @sin scripted = torch.jit.script(BasicModule()) try: # CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition. - torch_mlir.compile(scripted, torch_mlir.ExampleArgs().add_method("nonexistent", torch.ones(2, 3))) + torchscript.compile(scripted, torchscript.ExampleArgs().add_method("nonexistent", torch.ones(2, 3))) except Exception as e: print(e) diff --git a/projects/pt1/python/test/compile_api/already_traced.py b/projects/pt1/python/test/compile_api/already_traced.py index a719eb743c73..32f7b5653fca 100644 --- a/projects/pt1/python/test/compile_api/already_traced.py +++ b/projects/pt1/python/test/compile_api/already_traced.py @@ -6,23 +6,23 @@ # RUN: %PYTHON %s | FileCheck %s import torch -import torch_mlir +from torch_mlir import torchscript class BasicModule(torch.nn.Module): def forward(self, x): return torch.ops.aten.sin(x) example_arg = torch.ones(2, 3) -example_args = torch_mlir.ExampleArgs.get(example_arg) +example_args = torchscript.ExampleArgs.get(example_arg) traced = torch.jit.trace(BasicModule(), example_arg) -print(torch_mlir.compile(traced, example_args)) +print(torchscript.compile(traced, example_args)) # CHECK: module # CHECK-DAG: func.func @forward traced = torch.jit.trace(BasicModule(), example_arg) try: # CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition. - torch_mlir.compile(traced, torch_mlir.ExampleArgs().add_method("nonexistent", example_arg)) + torchscript.compile(traced, torchscript.ExampleArgs().add_method("nonexistent", example_arg)) except Exception as e: print(e) diff --git a/projects/pt1/python/test/compile_api/backend_legal_ops.py b/projects/pt1/python/test/compile_api/backend_legal_ops.py index 98c034930243..64ebf7a522fa 100644 --- a/projects/pt1/python/test/compile_api/backend_legal_ops.py +++ b/projects/pt1/python/test/compile_api/backend_legal_ops.py @@ -7,7 +7,7 @@ import torch -import torch_mlir +from torch_mlir import torchscript class AddmmModule(torch.nn.Module): def __init__(self): @@ -15,9 +15,9 @@ def __init__(self): def forward(self, x, y, z): return torch.ops.aten.addmm(x, y, z) -example_args = 3 * [torch_mlir.TensorPlaceholder([-1, -1], torch.float32)] +example_args = 3 * [torchscript.TensorPlaceholder([-1, -1], torch.float32)] -print(torch_mlir.compile(AddmmModule(), example_args, +print(torchscript.compile(AddmmModule(), example_args, output_type="torch", backend_legal_ops=["aten.addmm"])) # CHECK-LABEL: @forward # CHECK: torch.aten.addmm diff --git a/projects/pt1/python/test/compile_api/basic.py b/projects/pt1/python/test/compile_api/basic.py index 999d2fe4a820..0c516b620863 100644 --- a/projects/pt1/python/test/compile_api/basic.py +++ b/projects/pt1/python/test/compile_api/basic.py @@ -7,7 +7,7 @@ import torch -import torch_mlir +from torch_mlir import torchscript class TanhModule(torch.nn.Module): def __init__(self): @@ -18,24 +18,24 @@ def forward(self, x): tanh_example_input = torch.ones(2, 3) # Simplest case: One example argument. -print(torch_mlir.compile(TanhModule(), tanh_example_input)) +print(torchscript.compile(TanhModule(), tanh_example_input)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> # Use a TensorPlaceholder to represent dynamic axes. -placeholder = torch_mlir.TensorPlaceholder.like(tanh_example_input, dynamic_axes=[1]) -print(torch_mlir.compile(TanhModule(), placeholder)) +placeholder = torchscript.TensorPlaceholder.like(tanh_example_input, dynamic_axes=[1]) +print(torchscript.compile(TanhModule(), placeholder)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32> # Explicitly construct a TensorPlaceholder. -placeholder = torch_mlir.TensorPlaceholder([-1, 2], torch.float32) -print(torch_mlir.compile(TanhModule(), placeholder)) +placeholder = torchscript.TensorPlaceholder([-1, 2], torch.float32) +print(torchscript.compile(TanhModule(), placeholder)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[?,2],f32> -> !torch.vtensor<[?,2],f32> # Basic smoke test for the raw output type. -print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type=torch_mlir.OutputType.RAW)) +print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.RAW)) # CHECK: torch.nn_module { # CHECK: } : !torch.nn.Module<"{{.*}}.TanhModule"> @@ -47,12 +47,12 @@ def forward(self, lhs, rhs ): # N > 1 inputs. mm_example_inputs = [torch.ones(2, 3), torch.ones(3, 4)] -print(torch_mlir.compile(MmModule(), mm_example_inputs)) +print(torchscript.compile(MmModule(), mm_example_inputs)) # CHECK-LABEL: @forward # CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32> # Mixes Tensor's and TensorPlaceholder's. -mm_dynamic_inputs = [mm_example_inputs[0], torch_mlir.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1])] -print(torch_mlir.compile(MmModule(), mm_dynamic_inputs)) +mm_dynamic_inputs = [mm_example_inputs[0], torchscript.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1])] +print(torchscript.compile(MmModule(), mm_dynamic_inputs)) # CHECK-LABEL: @forward # CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[2,?],f32> diff --git a/projects/pt1/python/test/compile_api/do_test.py b/projects/pt1/python/test/compile_api/do_test.py index 7e5e4e245604..ccf127d71ba8 100644 --- a/projects/pt1/python/test/compile_api/do_test.py +++ b/projects/pt1/python/test/compile_api/do_test.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import Optional -import torch_mlir +from torch_mlir.torchscript import do import torch class Model(torch.nn.Module): @@ -28,12 +28,12 @@ def forward(self, x): return ModelOutput(x=2 * x, y=x+x) -torch_mlir.do(Model(), torch.ones(5), output_type="torch") -torch_mlir.do(ModelWithTuple(), torch.ones(5), output_type="torch") -torch_mlir.do(ModelWithNestedTuple(), torch.ones(5), output_type="torch") -torch_mlir.do(ModelWithDataclassOutput(), torch.ones(5), output_type="torch") +do(Model(), torch.ones(5), output_type="torch") +do(ModelWithTuple(), torch.ones(5), output_type="torch") +do(ModelWithNestedTuple(), torch.ones(5), output_type="torch") +do(ModelWithDataclassOutput(), torch.ones(5), output_type="torch") -torch_mlir.do(Model(), torch.ones(5), output_type="tosa") -torch_mlir.do(Model(), torch.ones(5), output_type="tosa", dtype=torch.bfloat16) -torch_mlir.do(Model(), torch.ones(5), output_type="tosa", dtype=torch.bfloat16, output_prefix="out") +do(Model(), torch.ones(5), output_type="tosa") +do(Model(), torch.ones(5), output_type="tosa", dtype=torch.bfloat16) +do(Model(), torch.ones(5), output_type="tosa", dtype=torch.bfloat16, output_prefix="out") diff --git a/projects/pt1/python/test/compile_api/make_fx.py b/projects/pt1/python/test/compile_api/make_fx.py index 62add20a576b..ec859d86e369 100644 --- a/projects/pt1/python/test/compile_api/make_fx.py +++ b/projects/pt1/python/test/compile_api/make_fx.py @@ -8,7 +8,7 @@ import functorch import torch -import torch_mlir +from torch_mlir import torchscript def simple(x): return x * x @@ -17,6 +17,6 @@ def simple(x): graph = functorch.make_fx(simple)(torch.randn(1,)) # Simplest case: One example argument. -print(torch_mlir.compile(graph, example_input)) +print(torchscript.compile(graph, example_input)) # CHECK-LABEL: @forward # CHECK: torch.aten.mul.Tensor %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[1],f32> \ No newline at end of file diff --git a/projects/pt1/python/test/compile_api/multiple_methods.py b/projects/pt1/python/test/compile_api/multiple_methods.py index f70b14ab68ab..067e775bfc71 100644 --- a/projects/pt1/python/test/compile_api/multiple_methods.py +++ b/projects/pt1/python/test/compile_api/multiple_methods.py @@ -6,7 +6,7 @@ # RUN: %PYTHON %s | FileCheck %s import torch -import torch_mlir +from torch_mlir import torchscript class TwoMethodsModule(torch.nn.Module): @@ -17,14 +17,14 @@ def cos(self, x): return torch.ops.aten.cos(x) -example_args = torch_mlir.ExampleArgs() +example_args = torchscript.ExampleArgs() example_args.add_method("sin", torch.ones(2, 3)) example_args.add_method("cos", torch.ones(2, 4)) # Note: Due to https://github.com/pytorch/pytorch/issues/88735 we need to # check the `use_tracing` case first. -print(torch_mlir.compile(TwoMethodsModule(), example_args, use_tracing=True)) +print(torchscript.compile(TwoMethodsModule(), example_args, use_tracing=True)) # CHECK: module # CHECK-DAG: func.func @sin # CHECK-DAG: func.func @cos @@ -34,8 +34,8 @@ def cos(self, x): # Otherwise the user would have to do this manually, which is tedious. This # technically mutates the user input model which is not great but probably okay # for this kind of API sugar. Users can always take full control of the process -# by scripting the model themselves before passing it to `torch_mlir.compile`. -print(torch_mlir.compile(TwoMethodsModule(), example_args)) +# by scripting the model themselves before passing it to `torchscript.compile`. +print(torchscript.compile(TwoMethodsModule(), example_args)) # CHECK: module # CHECK-DAG: func.func @sin # CHECK-DAG: func.func @cos diff --git a/projects/pt1/python/test/compile_api/output_type_spec.py b/projects/pt1/python/test/compile_api/output_type_spec.py index b975c2b5c0ae..92ed1e425d8d 100644 --- a/projects/pt1/python/test/compile_api/output_type_spec.py +++ b/projects/pt1/python/test/compile_api/output_type_spec.py @@ -7,7 +7,7 @@ import torch -import torch_mlir +from torch_mlir import torchscript class TanhModule(torch.nn.Module): def __init__(self): @@ -17,9 +17,9 @@ def forward(self, x): tanh_example_input = torch.ones(2, 3) -print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type=torch_mlir.OutputType.TORCH)) +print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.TORCH)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> -print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type="torch")) +print(torchscript.compile(TanhModule(), tanh_example_input, output_type="torch")) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> diff --git a/projects/pt1/python/test/compile_api/tracing.py b/projects/pt1/python/test/compile_api/tracing.py index ea74fea12ab4..bbf652f07a28 100644 --- a/projects/pt1/python/test/compile_api/tracing.py +++ b/projects/pt1/python/test/compile_api/tracing.py @@ -7,7 +7,7 @@ import torch -import torch_mlir +from torch_mlir import torchscript class TanhModule(torch.nn.Module): @@ -17,38 +17,38 @@ def forward(self, x): tanh_example_input = torch.ones(2, 3) # Simplest case: One example argument. -print(torch_mlir.compile(TanhModule(), tanh_example_input, use_tracing=True)) +print(torchscript.compile(TanhModule(), tanh_example_input, use_tracing=True)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> # Simplest case: Passed as a tuple. -print(torch_mlir.compile(TanhModule(), (tanh_example_input,), use_tracing=True)) +print(torchscript.compile(TanhModule(), (tanh_example_input,), use_tracing=True)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> # Simplest case: Passed as a list. -print(torch_mlir.compile(TanhModule(), [tanh_example_input], use_tracing=True)) +print(torchscript.compile(TanhModule(), [tanh_example_input], use_tracing=True)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> # TensorPlaceholder support. -placeholder = torch_mlir.TensorPlaceholder.like( +placeholder = torchscript.TensorPlaceholder.like( tanh_example_input, dynamic_axes=[1]) -print(torch_mlir.compile(TanhModule(), [placeholder], +print(torchscript.compile(TanhModule(), [placeholder], use_tracing=True, ignore_traced_shapes=True)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32> try: # CHECK: `ignore_traced_shapes` requires `use_tracing` - torch_mlir.compile(TanhModule(), [placeholder], ignore_traced_shapes=True) + torchscript.compile(TanhModule(), [placeholder], ignore_traced_shapes=True) except Exception as e: print(e) try: # CHECK: TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True` - torch_mlir.compile(TanhModule(), [placeholder], use_tracing=True) + torchscript.compile(TanhModule(), [placeholder], use_tracing=True) except Exception as e: print(e) @@ -60,13 +60,13 @@ def forward(self, x): try: # CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}' - torch_mlir.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True) + torchscript.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True) except Exception as e: print(e) try: # CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}' - torch_mlir.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True) + torchscript.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True) except Exception as e: print(e) diff --git a/projects/pt1/python/test/dynamo_fx_importer/basic.py b/projects/pt1/python/test/dynamo_fx_importer/basic.py index cea2f639f01d..fd3dcc7f4c2d 100644 --- a/projects/pt1/python/test/dynamo_fx_importer/basic.py +++ b/projects/pt1/python/test/dynamo_fx_importer/basic.py @@ -4,6 +4,8 @@ # Also available under a BSD-style license. See LICENSE. # RUN: %PYTHON %s | FileCheck %s +# Requires python>=3.10 +# UNSUPPORTED: true from typing import List diff --git a/projects/pt1/python/test/torchscript_e2e_test/basic.py b/projects/pt1/python/test/torchscript_e2e_test/basic.py index fa3f6f29729b..2dcface6f4e8 100644 --- a/projects/pt1/python/test/torchscript_e2e_test/basic.py +++ b/projects/pt1/python/test/torchscript_e2e_test/basic.py @@ -4,6 +4,8 @@ # Also available under a BSD-style license. See LICENSE. # RUN: %PYTHON %s | FileCheck %s +# Requires python>=3.10 +# UNSUPPORTED: true import torch diff --git a/projects/pt1/python/test/torchscript_e2e_test/compilation_failure.py b/projects/pt1/python/test/torchscript_e2e_test/compilation_failure.py index 9b9091452f01..36d81d83ab04 100644 --- a/projects/pt1/python/test/torchscript_e2e_test/compilation_failure.py +++ b/projects/pt1/python/test/torchscript_e2e_test/compilation_failure.py @@ -4,6 +4,8 @@ # Also available under a BSD-style license. See LICENSE. # RUN: %PYTHON %s | FileCheck %s +# Requires python>=3.10 +# UNSUPPORTED: true import torch diff --git a/projects/pt1/python/test/torchscript_e2e_test/error_reports.py b/projects/pt1/python/test/torchscript_e2e_test/error_reports.py index f3321285999a..1ebc3dd6dd42 100644 --- a/projects/pt1/python/test/torchscript_e2e_test/error_reports.py +++ b/projects/pt1/python/test/torchscript_e2e_test/error_reports.py @@ -4,6 +4,8 @@ # Also available under a BSD-style license. See LICENSE. # RUN: %PYTHON %s | FileCheck %s +# Requires python>=3.10 +# UNSUPPORTED: true from typing import List, Tuple, Dict diff --git a/projects/pt1/python/test/torchscript_e2e_test/non_tensor_values.py b/projects/pt1/python/test/torchscript_e2e_test/non_tensor_values.py index a1c8c5adfdf4..899dae0c1b9f 100644 --- a/projects/pt1/python/test/torchscript_e2e_test/non_tensor_values.py +++ b/projects/pt1/python/test/torchscript_e2e_test/non_tensor_values.py @@ -4,6 +4,8 @@ # Also available under a BSD-style license. See LICENSE. # RUN: %PYTHON %s | FileCheck %s +# Requires python>=3.10 +# UNSUPPORTED: true from typing import List, Tuple, Dict diff --git a/projects/pt1/python/test/torchscript_e2e_test/runtime_failure.py b/projects/pt1/python/test/torchscript_e2e_test/runtime_failure.py index 3581c1b6d41f..a5cc12e66857 100644 --- a/projects/pt1/python/test/torchscript_e2e_test/runtime_failure.py +++ b/projects/pt1/python/test/torchscript_e2e_test/runtime_failure.py @@ -4,6 +4,8 @@ # Also available under a BSD-style license. See LICENSE. # RUN: %PYTHON %s | FileCheck %s +# Requires python>=3.10 +# UNSUPPORTED: true import torch diff --git a/projects/pt1/python/test/torchscript_e2e_test/submodule.py b/projects/pt1/python/test/torchscript_e2e_test/submodule.py index c88ad53b31b3..8fc520c94396 100644 --- a/projects/pt1/python/test/torchscript_e2e_test/submodule.py +++ b/projects/pt1/python/test/torchscript_e2e_test/submodule.py @@ -4,6 +4,8 @@ # Also available under a BSD-style license. See LICENSE. # RUN: %PYTHON %s | FileCheck %s +# Requires python>=3.10 +# UNSUPPORTED: true import torch diff --git a/projects/pt1/python/torch_mlir/compiler_utils.py b/projects/pt1/python/torch_mlir/compiler_utils.py index 310ad6b73731..3351f4eddc93 100644 --- a/projects/pt1/python/torch_mlir/compiler_utils.py +++ b/projects/pt1/python/torch_mlir/compiler_utils.py @@ -31,7 +31,8 @@ class TorchMlirCompilerError(Exception): def run_pipeline_with_repro_report(module, pipeline: str, - description: str): + description: str, + enable_ir_printing: bool = False): """Runs `pipeline` on `module`, with a nice repro report if it fails.""" module_name = get_module_name_for_debug_dump(module) try: @@ -40,8 +41,11 @@ def run_pipeline_with_repro_report(module, asm_for_error_report = module.operation.get_asm( large_elements_limit=10, enable_debug_info=True) # Lower module in place to make it ready for compiler backends. - with module.context: + with module.context as ctx: pm = PassManager.parse(pipeline) + if enable_ir_printing: + ctx.enable_multithreading(False) + pm.enable_ir_printing() pm.run(module.operation) except Exception as e: # TODO: More robust. @@ -64,7 +68,7 @@ def run_pipeline_with_repro_report(module, {sys.stderr.getvalue()} python exception: {e} - + For Torch-MLIR developers, the error can be reproduced with: $ torch-mlir-opt -pass-pipeline='{pipeline}' {filename} Add '{debug_options}' to get the IR dump for debugging purpose. diff --git a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp index 4bcb9347b5aa..8708ff06a5a2 100644 --- a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp +++ b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp @@ -30,7 +30,7 @@ namespace lazy { /// Returns true if a string begins with another. inline bool beginswith(const std::string& s, const std::string& t) { - return s.size() >= t.size() && s.compare(0, t.size(), t) == 0; + return s.size() >= t.size() && s.compare(0, t.size(), t) == 0; } struct ReferenceLazyBackendDeviceType : public BackendDeviceType { @@ -73,10 +73,8 @@ class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl { // Vendor backend specific lowering can be exec here before returning. for (const auto& instance : instances) { TORCH_CHECK( - instance->in_mark_step, - "Compile outside of mark step:\n", - GetComputationBackendText(instance) - ); + instance->in_mark_step, "Compile outside of mark step:\n", + GetComputationBackendText(instance)); // Store computation instance for external access after compilation. GetLatestComputation() = instance; } @@ -114,16 +112,17 @@ class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl { // Convert any lazy devices to cpu devices to ensure // that the values are actually computed if (node->outputs().size() == 1 && - node->output()->type()->kind() == - c10::TypeKind::DeviceObjType) { - auto value_sym = torch::jit::Symbol::attr("value"); - TORCH_CHECK(node->hasAttribute(value_sym), - "Expected node to have 'value' attribute."); - TORCH_CHECK(node->kindOf(value_sym) == torch::jit::AttributeKind::s, - "Expected 'value' attribute to be a string."); - if (beginswith(node->s(value_sym), "lazy")) { - node->s_(value_sym, "cpu"); - } + node->output()->type()->kind() == c10::TypeKind::DeviceObjType) { + auto value_sym = torch::jit::Symbol::attr("value"); + TORCH_CHECK( + node->hasAttribute(value_sym), + "Expected node to have 'value' attribute."); + TORCH_CHECK( + node->kindOf(value_sym) == torch::jit::AttributeKind::s, + "Expected 'value' attribute to be a string."); + if (beginswith(node->s(value_sym), "lazy")) { + node->s_(value_sym, "cpu"); + } } } @@ -132,7 +131,8 @@ class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl { for (const auto& argument : arguments) { const auto mlir_data = std::static_pointer_cast(argument); - auto* info = dynamic_cast(mlir_data->mlir_info()); + auto* info = + dynamic_cast(mlir_data->mlir_info()); TORCH_CHECK(info); if (info->scalar.has_value()) { stack.emplace_back(info->scalar.value()); diff --git a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp index f4b8cd9ba579..2cbb6d6f16dc 100644 --- a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp +++ b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp @@ -8,8 +8,8 @@ //===----------------------------------------------------------------------===// #include "torch/csrc/jit/python/pybind.h" -#include "torch/csrc/lazy/core/config.h" #include "torch/csrc/lazy/backend/backend_interface.h" +#include "torch/csrc/lazy/core/config.h" #include #include @@ -56,8 +56,8 @@ void Initialize() { } if (ir_debug) { - FLAGS_torch_lazy_ir_debug = true; - std::cout << "Enabled lazy tensor IR debugging." << std::endl; + FLAGS_torch_lazy_ir_debug = true; + std::cout << "Enabled lazy tensor IR debugging." << std::endl; } } @@ -82,15 +82,17 @@ PYBIND11_MODULE(_REFERENCE_LAZY_BACKEND, m) { torch::lazy::GetLatestComputation().get()); return py::cast(computation); }); - m.def("set_parameter_name", - [](const at::Tensor& tensor, const std::string& name) -> bool { - torch::lazy::DeviceData* ir_node = torch::lazy::device_data_cast(tensor); - if (ir_node) { - ir_node->SetName(name); - return true; - } - return false; - }); + m.def( + "set_parameter_name", + [](const at::Tensor& tensor, const std::string& name) -> bool { + torch::lazy::DeviceData* ir_node = + torch::lazy::device_data_cast(tensor); + if (ir_node) { + ir_node->SetName(name); + return true; + } + return false; + }); m.def("_initialize", []() { NoGilSection gil; Initialize(); diff --git a/projects/pt1/python/torch_mlir/dynamo.py b/projects/pt1/python/torch_mlir/dynamo.py index cb15018d6887..b9420f1f8d34 100644 --- a/projects/pt1/python/torch_mlir/dynamo.py +++ b/projects/pt1/python/torch_mlir/dynamo.py @@ -19,12 +19,12 @@ def _get_decomposition_table(): """Get a decomposition table suitable for Torch-MLIR. - + Sometimes TorchDynamo traces slightly different ops than what TorchScript captures. Historically we have been driven by the ops captured by TorchScript, so we try to decompose the ops captured by TorchDynamo into other ops that we already support. - + There isn't a highly principled solution here. Torch-MLIR currently supports a somewhat random set of ops, added in a demand-driven way over time, including direct backend support and decompositions internal to Torch-MLIR. @@ -130,7 +130,7 @@ def make_simple_dynamo_backend(user_backend): Args: user_backend: A function with the signature used by ordinary TorchDynamo backends. But the torch.fx.GraphModule passed to it - will be normalized for consumption by `torch_mlir.compile`. + will be normalized for consumption by `torchscript.compile`. Returns: A function with the signature used by TorchDynamo backends. """ diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/CMakeLists.txt b/projects/pt1/python/torch_mlir/jit_ir_importer/CMakeLists.txt index c2883b3dca84..6c2ccf62eb78 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/CMakeLists.txt +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/CMakeLists.txt @@ -4,9 +4,9 @@ ## Declare the sources of the Python module. -declare_mlir_python_sources(TorchMLIRPythonSources.JitIRImporter +declare_mlir_python_sources(TorchMLIRPythonTorchExtensionsSources.JitIRImporter ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" - ADD_TO_PARENT TorchMLIRPythonSources + ADD_TO_PARENT TorchMLIRPythonTorchExtensionsSources SOURCES_GLOB jit_ir_importer/*.py ) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index c18817070a2d..31ce183bb7a0 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -59,12 +59,69 @@ def aten〇triu〡shape(self: List[int], diagonal: int = 0) -> List[int]: def aten〇tril〡shape(self: List[int], diagonal: int = 0) -> List[int]: return upstream_shape_functions.unary(self) -def aten〇atan〡shape(self: List[int]) -> List[int]: +@check_shape_function([ + Invocation(TensorOfShape(2, 3, 4)), # Basic case. + Invocation(TensorOfShape(2, 3, 4), dim1=1, dim2=2), # Test explicit `dim1` and `dim2`. + Invocation(TensorOfShape(2, 3, 4), dim1=-1, dim2=-2, offset=1), # Positive `offset`. + Invocation(TensorOfShape(2, 3, 4), offset=-1), # Negative `offset``. + Invocation(TensorOfShape(2, 3, 4), offset=3), # Empty result due to large `offset`. + ErrorInvocation(TensorOfShape(2)), # Input one-dimensional. + ErrorInvocation(TensorOfShape(2, 3, 4), dim1=1, dim2=1), # `dim1` and `dim2` equal. + ErrorInvocation(TensorOfShape(2, 3, 4), dim1=3, dim2=1), # `dim1` out of bounds. +]) +def aten〇diagonal〡shape(self: List[int], offset: int = 0, dim1: int = 0, dim2: int = 1) -> List[int]: + assert len(self) >= 2, "input must have at least two dimensions" + dim1 = upstream_shape_functions.maybe_wrap_dim(dim1, len(self)) + dim2 = upstream_shape_functions.maybe_wrap_dim(dim2, len(self)) + assert dim1 != dim2, "diagonal dimensions cannot be identical" + + diagonal: List[int] = [] + for i, self_dim in enumerate(self): + if (i==dim1) or (i==dim2): + pass + else: + diagonal.append(self_dim) + + diag_size = max(min(self[dim1], self[dim2] - offset), 0) + if offset<0: + diag_size = max(min(self[dim1] + offset, self[dim2]), 0) + diagonal.append(diag_size) + + return diagonal + +def aten〇sin〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇asin〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇asinh〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇cos〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇cosh〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇acos〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇acosh〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇tan〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) def aten〇tanh〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇atan〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇atanh〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇erf〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -92,18 +149,6 @@ def aten〇exp〡shape(self: List[int]) -> List[int]: def aten〇expm1〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) -def aten〇sin〡shape(self: List[int]) -> List[int]: - return upstream_shape_functions.unary(self) - -def aten〇cos〡shape(self: List[int]) -> List[int]: - return upstream_shape_functions.unary(self) - -def aten〇asin〡shape(self: List[int]) -> List[int]: - return upstream_shape_functions.unary(self) - -def aten〇acos〡shape(self: List[int]) -> List[int]: - return upstream_shape_functions.unary(self) - def aten〇cosine_similarity〡shape(x1: List[int], x2: List[int], dim: int = 1, eps: float = 1e-08) -> List[int]: broadcast = upstream_shape_functions.broadcast(x1, x2) return broadcast[:dim] + broadcast[dim + 1:] @@ -138,6 +183,9 @@ def aten〇log10〡shape(self: List[int]) -> List[int]: def aten〇log1p〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇logit〡shape(self: List[int], eps: Optional[float] = None) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇rsqrt〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -215,22 +263,51 @@ def aten〇clamp_max〡shape(self: List[int], max: float) -> List[int]: def aten〇rsub〇Scalar〡shape(self: List[int], other: float, alpha: float = 1) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇quantize_per_channel〡shape(self: List[int], scales: List[int], zero_points: List[int], axis: int, dtype: int) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇quantize_per_tensor〡shape(self: List[int], scale: float, zero_point: int, dtype: int) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇dequantize〇self〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇dequantize〇tensor〡shape(qtensor: List[int]) -> List[int]: + return upstream_shape_functions.unary(qtensor) + +def aten〇int_repr〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇_make_per_channel_quantized_tensor〡shape(self: List[int], scale: List[int], zero_point: List[int], axis: int) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇_make_per_tensor_quantized_tensor〡shape(self: List[int], scale: float, zero_point: int) -> List[int]: + return upstream_shape_functions.unary(self) + def prims〇convert_element_type〡shape(a: List[int], dtype: int) -> List[int]: return upstream_shape_functions.unary(a) +def aten〇grid_sampler〡shape(input: List[int], grid: List[int], interpolation_mode: int, padding_mode: int, align_corners: bool) -> List[int]: + output = [input[0],input[1],grid[1],grid[2]] + return output + def prims〇collapse〡shape(a: List[int], start: int, end: int) -> List[int]: # Obtained through trial and error on a few examples in PyTorch: - assert start <= len(a), "start out of bounds" - assert end <= len(a), "end out of bounds" + assert start < len(a), "start out of bounds" + assert end < len(a), "end out of bounds" assert start >= 0, "start out of bounds" assert end >= 0, "end out of bounds" assert start <= end, "start must be less than or equal to end" - # Example: + # Examples: # # torch._prims.collapse(torch.empty(2,3,4), 1,2).shape - # is + # is # torch.Size([2, 12]) + # + # torch._prims.collapse(torch.empty(2,3,4), 1,3).shape + # gives + # --> 524 assert idx >= 0 and idx < rank or idx == 0 collapsed: List[int] = [] for i in range(start): @@ -307,6 +384,17 @@ def aten〇clone〡shape(self: List[int], memory_format: Optional[int] = None) - def aten〇lift_fresh_copy〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +@check_shape_function([ + Invocation(TensorOfShape(1, 2, 3), TensorOfShape(4, 1, 3)), # two dimensions to broadcast, self[0] and other[1] + ErrorInvocation(TensorOfShape(3), TensorOfShape(2, 3)), # different number of dimensions + ErrorInvocation(TensorOfShape(2, 3), TensorOfShape(4, 3)) # non-broadcastable dimensions +]) +def aten〇linalg_cross〡shape(self: List[int], other: List[int], dim: int = -1) -> List[int]: + assert len(self) == len(other), "inputs must have the same number of dimensions" + for i in range(len(self)): + assert (self[i] == other[i]) or self[i] == 1 or other[i] == 1, f"the size of first tensor ({self[i]}) must match the size of second tensor ({other[i]}) at dimension {i}" + return upstream_shape_functions.broadcast(self, other) + def aten〇_log_softmax_backward_data〡shape(grad_output: List[int], output: List[int], dim: int, input_dtype: int) -> List[int]: return upstream_shape_functions.unary(grad_output) @@ -316,6 +404,12 @@ def aten〇isnan〡shape(self: List[int]) -> List[int]: def aten〇isinf〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇isneginf〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇isposinf〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇ne〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) @@ -352,6 +446,12 @@ def aten〇div〇Scalar〡shape(self: List[int], other: float) -> List[int]: def aten〇remainder〇Scalar〡shape(self: List[int], other: float) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇remainder〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇fmod〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇floor_divide〇Scalar〡shape(self: List[int], other: float) -> List[int]: return upstream_shape_functions.unary(self) @@ -376,6 +476,9 @@ def aten〇elu〡shape(self: List[int], alpha: float = 1, scale: float = 1, inpu def aten〇prelu〡shape(self: List[int], weight: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇selu〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇gather〡shape(self: List[int], dim: int, index: List[int], sparse_grad: bool = False) -> List[int]: return upstream_shape_functions.unary(index) @@ -445,6 +548,15 @@ def aten〇std〇dim〡shape(self: List[int], dim: Optional[List[int]], unbiased def aten〇std〇correction〡shape(self: List[int], dim: Optional[List[int]] = None, correction: Optional[float] = None, keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) +@check_shape_function([ + Invocation(TensorOfShape(2, 3)), # Basic case. + ErrorInvocation(TensorOfShape(2, 3, 4)), # Too many dimensions. + ErrorInvocation(TensorOfShape(2)), # Too few dimensions. +]) +def aten〇trace〡shape(self: List[int]) -> List[int]: + assert len(self) == 2, "input must have rank 2" + return [] + @check_shape_function([ Invocation(TensorOfShape(2, 3, 4)), # Basic case. Invocation(TensorOfShape(2, 3, 4), dim=0), # Test explicit `dim`. @@ -470,6 +582,9 @@ def aten〇one_hot〡shape(self: List[int], num_classes: int = -1) -> List[int]: def aten〇any〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]: return upstream_shape_functions.argmax(self, dim, keepdim) +def aten〇all〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]: + return upstream_shape_functions.argmax(self, dim, keepdim) + def aten〇max〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> Tuple[List[int], List[int]]: reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim) return reduced_shape, reduced_shape @@ -498,7 +613,7 @@ def aten〇pixel_shuffle〡shape(self: List[int], upscale_factor: int) -> List[i assert len(self) >= 3, "input must be at least rank-3 in pixel_shuffle" upscale_factor_squared = upscale_factor * upscale_factor assert self[-3] % (upscale_factor_squared) == 0, "number of input channels must be divisible by upscale_factor^2 in pixel_shuffle" - + out = self[0:-3] out.append(self[-3] // upscale_factor_squared) out.append(self[-2] * upscale_factor) @@ -618,9 +733,120 @@ def aten〇_unsafe_view〡shape(self: List[int], size: List[int]) -> List[int]: def aten〇resize_〡shape(self: List[int], size: List[int], memory_format: Optional[int] = None) -> List[int]: return size +def _pool3d_shape_check( + input: List[int], + kD: int, + kH: int, + kW: int, + dD: int, + dH: int, + dW: int, + padD: int, + padH: int, + padW: int, + dilationD: int, + dilationH: int, + dilationW: int, + outputDepth: int, + outputHeight: int, + outputWidth: int, +): + ndim = len(input) + + assert kD > 0 and kH > 0 and kW > 0 + assert dD > 0 and dH > 0 and dW > 0 + assert dilationD > 0 and dilationH > 0 and dilationW > 0 + assert ndim == 4 or ndim == 5, "pool3d: input dimensions must be 4 or 5" + if ndim == 4: + assert input[0] != 0 and input[1] != 0 and input[2] != 0 and input[3] != 0 + else: + assert input[0] != 0 and input[1] != 0 and input[2] != 0 and input[3] != 0 and input[4] != 0 + + assert kD // 2 >= padD and kW // 2 >= padW and kH // 2 >= padH + assert outputDepth >= 1 and outputWidth >= 1 and outputHeight >= 1 + +def _max_pool3d( + input: List[int], + kernel_size: List[int], + stride: List[int], + padding: List[int], + dilation: List[int], + ceil_mode: bool, +): + assert ( + len(kernel_size) == 1 or len(kernel_size) == 3 + ), "max_pool3d: kernel_size must either be a single int, or a tuple of three ints" + (kD, kH, kW) = (kernel_size[0], kernel_size[0], kernel_size[0]) if len(kernel_size) == 1 else (kernel_size[0], kernel_size[1], kernel_size[2]) + + assert ( + len(stride) == 0 or len(stride) == 1 or len(stride) == 3 + ), "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints" + + if len(stride) == 0: + (dD, dH, dW) = (kD, kD, kD) + elif len(stride) == 1: + (dD, dH, dW) = (stride[0], stride[0], stride[0]) + else: # len(stride) == 3 + (dD, dH, dW) = (stride[0], stride[1], stride[2]) + + assert ( + len(padding) == 1 or len(padding) == 3 + ), "max_pool3d: padding must either be a single int, or a tuple of thee ints" + (padD, padH, padW) = (padding[0], padding[0], padding[0]) if len(padding) == 1 else (padding[0], padding[1], padding[2]) + + assert ( + len(dilation) == 1 or len(dilation) == 3 + ), "max_pool3d: dilation must be either a single int, or a tuple of three ints" + (dilationD, dilationH, dilationW) = (dilation[0], dilation[0], dilation[0]) if len(dilation) == 1 else (dilation[0], dilation[1], dilation[2]) + + assert len(input) == 4 or len(input) == 5 + nbatch = input[-5] if len(input) == 5 else 1 + nInputPlane = input[-4] + inputDepth = input[-3] + inputHeight = input[-2] + inputWidth = input[-1] + + outputDepth = upstream_shape_functions.pooling_output_shape(inputDepth, kD, padD, dD, dilationD, ceil_mode) + outputHeight = upstream_shape_functions.pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode) + outputWidth = upstream_shape_functions.pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode) + + _pool3d_shape_check( + input, + kD, + kH, + kW, + dD, + dH, + dW, + padD, + padH, + padW, + dilationD, + dilationH, + dilationW, + outputDepth, + outputHeight, + outputWidth, + ) + + if len(input) == 4: + return [nInputPlane, outputDepth, outputHeight, outputWidth] + else: + return [nbatch, nInputPlane, outputDepth, outputHeight, outputWidth] + def aten〇max_pool2d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> List[int]: return upstream_shape_functions.max_pool2d(self, kernel_size, stride, padding, dilation, ceil_mode) +@check_shape_function([ + Invocation(TensorOfShape(3, 6, 10, 10, 10), [2]), # Basic using defaults + Invocation(TensorOfShape(3, 6, 10, 10, 10), [4], [2], [2], [2]), # Using single values for each parameter + Invocation(TensorOfShape(3, 6, 64, 64, 64), [4, 6, 8], [2, 4, 2], [1, 2, 4], [1, 2, 4]), # Using dimensions should be + ErrorInvocation(TensorOfShape(3, 6, 2, 2, 2), [4]), # Input is too small + ErrorInvocation(TensorOfShape(3, 6, 10, 10, 10), [4], [2], [4], [2]), # The following relationship between kernel and padding needs to apply: Kernel size >= 2 * padding size +]) +def aten〇max_pool3d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0, 0,), dilation: List[int] = (1, 1, 1,), ceil_mode: bool = False) -> List[int]: + return _max_pool3d(self, kernel_size, stride, padding, dilation, ceil_mode) + def aten〇max_pool2d_with_indices〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> Tuple[List[int], List[int]]: maxpool2d = indices = upstream_shape_functions.max_pool2d(self, kernel_size, stride, padding, dilation, ceil_mode) return maxpool2d, indices @@ -731,6 +957,24 @@ def aten〇avg_pool2d〡shape(self: List[int], kernel_size: List[int], stride: L def aten〇adaptive_avg_pool2d〡shape(self: List[int], output_size: List[int]) -> List[int]: return upstream_shape_functions.adaptive_avg_pool2d(self, output_size) +def adaptive_max_pool2d(self: List[int], out: List[int]): + assert len(out) == 2 + assert len(self) == 3 or len(self) == 4 + + for i in range(len(self)): + assert self[i] != 0 + + shape: List[int] = [] + for i in range(len(self) - 2): + shape.append(self[i]) + for j in range(len(out)): + shape.append(out[j]) + + return shape, shape + +def aten〇adaptive_max_pool2d〡shape(self: List[int], output_size: List[int]) -> Tuple[List[int], List[int]]: + return adaptive_max_pool2d(self, output_size) + def aten〇flatten〇using_ints〡shape(self: List[int], start_dim: int = 0, end_dim: int = -1) -> List[int]: return upstream_shape_functions.flatten(self, start_dim, end_dim) @@ -832,6 +1076,9 @@ def aten〇copy〡shape(self: List[int], src: List[int], non_blocking: bool = Fa def aten〇uniform〡shape(self: List[int], from_: float = 0., to: float = 1., generator: Any = None) -> List[int]: return self +def aten〇exponential〡shape(self: List[int], lambd: float = 1., generator: Any = None) -> List[int]: + return self + def aten〇rand〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: return size @@ -872,6 +1119,9 @@ def aten〇randn〡shape(size: List[int], dtype: Optional[int] = None, layout: O def aten〇randn〇generator〡shape(size: List[int], generator: Any, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: return size +def aten〇normal_functional〡shape(self: List[int], mean: float = 0., std: float = 1., generator: Any = None) -> List[int]: + return self + def aten〇arange〇start_step〡shape(start: float, end: float, step: float = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: return upstream_shape_functions.arange_start_step(start, end, step, dtype, layout, device, pin_memory) @@ -893,6 +1143,9 @@ def aten〇fake_quantize_per_tensor_affine〡shape(self: List[int], scale: float def aten〇fake_quantize_per_tensor_affine〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int, quant_min: int, quant_max: int) -> int: return self_rank_dtype[1] +def aten〇linspace〡shape(start: float, end: float, steps: int, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: + return [steps] + @check_shape_function([ Invocation(TensorOfShape(2, 3), TensorOfShape(2, 3)), # Basic case. Invocation(TensorOfShape(2, 3), TensorOfShape(3)), # Rank broadcasting. @@ -1050,9 +1303,15 @@ def aten〇where〇ScalarOther〡shape(condition: List[int], self: List[int], ot def aten〇where〇ScalarSelf〡shape(condition: List[int], self: float, other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(condition, other) +def aten〇nan_to_num〡shape(self: List[int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇lerp〇Tensor〡shape(self: List[int], end: List[int], weight: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, upstream_shape_functions.broadcast(end, weight)) +def aten〇lerp〇Scalar〡shape(self: List[int], end: List[int], weight: float) -> List[int]: + return upstream_shape_functions.broadcast(self, end) + def aten〇addcmul〡shape(self: List[int], tensor1: List[int], tensor2: List[int], value: float = 1) -> List[int]: return upstream_shape_functions.broadcast(self, upstream_shape_functions.broadcast(tensor1, tensor2)) @@ -1131,12 +1390,50 @@ def aten〇view_as_real〡dtype(self_rank_dtype: Tuple[int, int]) -> int: def aten〇conv2d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), groups: int = 1) -> List[int]: return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups) +def aten〇conv3d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1, 1,), padding: List[int] = (0, 0, 0,), dilation: List[int] = (1, 1, 1,), groups: int = 1) -> List[int]: + return upstream_shape_functions.conv3d(input, weight, bias, stride, padding, dilation, groups) + def aten〇conv_transpose2d〇input〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), output_padding: List[int] = (0, 0,), groups: int = 1, dilation: List[int] = (1, 1,)) -> List[int]: return upstream_shape_functions.conv_transpose2d_input(input, weight, bias, stride, padding, output_padding, groups, dilation) +def aten〇conv_tbc〡shape(self: List[int], weight: List[int], bias: List[int], pad: int = 0) -> List[int]: + assert len(self) == 3 # only 1d is supported by tbc + assert len(weight) == 3 + assert len(bias) == 1 + + # tbc -> bct + time = self[0] + batch = self[1] + channels = self[2] + + kernel_width = weight[0] + channels_w = weight[1] + out_channels = weight[2] + + # out_channels_b = bias[0] + + assert channels == channels_w + # the out_channels in weights and biases should also match, but this assert doesn't work because typing problems + # assert out_channels == out_channels_b + + self_bct = [batch, channels, time] + weight_bct = [out_channels, channels, kernel_width] + bias_bct = bias + + # use existing shape inf + output_size_bct = upstream_shape_functions.conv_forwards(self, weight, bias, stride=[1], padding=[pad], dilation=[], transposed=False, output_padding=[], groups=1) + + batch_out, channels_out, time_out = output_size_bct + + # bct -> tbc + return [time_out, batch_out, channels_out] + def aten〇convolution〡shape(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> List[int]: return upstream_shape_functions.conv_forwards(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups) +def aten〇conv1d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), dilation: List[int] = (1,), groups: int = 1) -> List[int]: + return upstream_shape_functions.conv_forwards(input, weight, bias, stride, padding, dilation, transposed=False, output_padding=[], groups=1) + def aten〇_convolution〡shape(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]: return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups) @@ -1152,6 +1449,15 @@ def aten〇convolution_backward〡shape(grad_output: List[int], input: List[int] def aten〇batch_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]: return upstream_shape_functions.batch_norm(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled) +def aten〇group_norm〡shape(input: List[int], num_groups: int, weight: Optional[List[int]] = None, bias: Optional[List[int]] = None, eps: float = 1.0000000000000001e-05, cudnn_enabled: bool = True) -> List[int]: + return upstream_shape_functions.unary(input) + +def aten〇native_group_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], N: int, C: int, HxW: int, group: int, eps: float) -> Tuple[List[int], List[int], List[int]]: + return upstream_shape_functions.unary(input), [N, group], [N, group] + +def aten〇instance_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]: + return upstream_shape_functions.unary(input) + def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]: return upstream_shape_functions.slice(self, dim, start, end, step) @@ -1272,9 +1578,57 @@ def pad_shape_fn(input: List[int], pad: List[int]): def aten〇constant_pad_nd〡shape(self: List[int], pad: List[int], value: float = 0) -> List[int]: return pad_shape_fn(self, pad) +def aten〇replication_pad2d〡shape(self: List[int], padding: List[int]) -> List[int]: + assert len(self) >= 2 + assert len(padding) == 4, 'padding size expected to be 4' + return pad_shape_fn(self, padding) + +def aten〇replication_pad2d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + def aten〇pad〡shape(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]: return pad_shape_fn(self, pad) +#Padding size must be smaller than the size of the last dimension +@check_shape_function([ErrorInvocation(TensorOfShape(1, 2, 4), padding=[4,1]), + Invocation(TensorOfShape(1, 2, 4), padding=[3,3]), + ErrorInvocation(TensorOfShape(1, 2, 4), padding=[1,4]), + ErrorInvocation(TensorOfShape(1, 4), padding=[4,1]), + Invocation(TensorOfShape(1, 4), padding=[3,3]), + ErrorInvocation(TensorOfShape(1, 4), padding=[1,4])]) +def aten〇reflection_pad1d〡shape(self: List[int], padding: List[int]) -> List[int]: + assert len(self) >= 2 + hdim = self[-1] + padding_left = padding[0] + padding_right = padding[1] + assert padding_left < hdim and padding_right < hdim + return pad_shape_fn(self, padding) + + +# Padding size must be smaller than corresponding dimension +@check_shape_function([ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,2,1,1]), + ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,1,1,1]), + ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,1,1,3]), + ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,1]), + Invocation(TensorOfShape(2, 2, 2), padding=[1,1,1,1]), + ErrorInvocation(TensorOfShape(2, 2, 2), padding=[1,1,2,1]), + ErrorInvocation(TensorOfShape(2, 2, 2), padding=[1,1,2,2])]) +def aten〇reflection_pad2d〡shape(self: List[int], padding: List[int]) -> List[int]: + assert len(self) >= 2 + vdim = self[-2] + hdim = self[-1] + + assert len(padding) == 4, 'padding size expected to be 4' + padding_left = padding[0] + padding_right = padding[1] + padding_top = padding[2] + padding_bottom = padding[3] + assert padding_left < hdim and padding_right < hdim + assert padding_top < vdim and padding_bottom < vdim + + return pad_shape_fn(self, padding) + # TODO: upstream this def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]: assert len(indices) <= len(self), "More indices than dimensions to index" @@ -1401,9 +1755,15 @@ def aten〇nonzero_static〡shape(self: List[int], size: int, fill_value: int = def aten〇linalg_vector_norm〡shape(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) +def aten〇linalg_norm〡shape(self: List[int], ord: Optional[float] = None, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: + return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) + def aten〇frobenius_norm〇dim〡shape(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0) +def aten〇norm〇Scalar〡shape(self: List[int], p: float = 2) -> List[int]: + return upstream_shape_functions.sum_mean_dim(self, None, False, None) + def aten〇norm〇ScalarOpt_dim〡shape(self: List[int], p: Optional[float], dim: List[int], keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0) @@ -1551,6 +1911,18 @@ def prims〇split_dim〡dtype(a_rank_dtype: Tuple[int, int], dim: int, outer_len _, a_dtype = a_rank_dtype return a_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇cosh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇acosh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -1572,6 +1944,16 @@ def aten〇sin〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇asin〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇asinh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇cos〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -1622,6 +2004,11 @@ def aten〇log1p〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇logit〡dtype(self_rank_dtype: Tuple[int, int], eps: Optional[float] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇rsqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -1701,12 +2088,34 @@ def aten〇avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: Lis self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7, 8)], kernel_size=[2, 2, 2])) +def aten〇avg_pool3d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype( tensor_shapes=[(2, 3, 5), (3,), (3,), (3,), (3,)], training=False, momentum=0.1, eps=1e-5, cudnn_enabled=True)) def aten〇batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], running_mean_rank_dtype: Optional[Tuple[int, int]], running_var_rank_dtype: Optional[Tuple[int, int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> int: input_rank, input_dtype = input_rank_dtype return input_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], error_types={*all_integer_dtypes()}, num_groups=1)) +def aten〇group_norm〡dtype(input_rank_dtype: Tuple[int, int], num_groups: int, weight_rank_dtype: Optional[Tuple[int, int]] = None, bias_rank_dtype: Optional[Tuple[int, int]] = None, eps: float = 1.0000000000000001e-05, cudnn_enabled: bool = True) -> int: + input_rank, input_dtype = input_rank_dtype + assert not is_integer_dtype(input_dtype) + return input_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7), (3,), (3,)], error_types={*all_integer_dtypes()}, N=2, C=3, HxW=35, group=1, eps=0.000001)) +def aten〇native_group_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], N: int, C: int, HxW: int, group: int, eps: float) -> Tuple[int, int, int]: + input_rank, input_dtype = input_rank_dtype + assert not is_integer_dtype(input_dtype) + return input_dtype, input_dtype, input_dtype + +# device is not supported hence unable to check the dtype function +def aten〇instance_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], running_mean_rank_dtype: Optional[Tuple[int, int]], running_var_rank_dtype: Optional[Tuple[int, int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇bernoulli_〇float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int: self_rank, self_dtype = self_rank_dtype @@ -1803,6 +2212,27 @@ def aten〇constant_pad_nd〡dtype(self_rank_dtype: Tuple[int, int], pad: List[i self_rank, self_dtype = self_rank_dtype return self_dtype +def aten〇grid_sampler〡dtype(input_rank_dtype: Tuple[int, int], grid_rank_dtype: Tuple[int, int], interpolation_mode: int, padding_mode: int, align_corners: bool) -> int: + input_rank, input_dtype = input_rank_dtype + grid_rank, grid_dtype = input_rank_dtype + return input_dtype + +@check_dtype_function([ErrorInvocation(TensorOfShape(2, 3, 4), padding=1), + ErrorInvocation(TensorOfShape(2, 3, 4), padding=[]), + ErrorInvocation(TensorOfShape(2, 3, 4), padding=[2]), + Invocation(TensorOfShape(2, 3, 4), padding=[2,1]), + Invocation(TensorOfShape(5, 5, 4), padding=[1,2]), + ErrorInvocation(TensorOfShape(2, 3, 4), padding=[3,2,1])]) +def aten〇reflection_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert len(padding) == 2, 'padding size expected to be 2' + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(4, 2, 2)], padding=[1,1,1,1])) +def aten〇reflection_pad2d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇contiguous〡dtype(self_rank_dtype: Tuple[int, int], memory_format: int = 0) -> int: self_rank, self_dtype = self_rank_dtype @@ -1998,6 +2428,19 @@ def aten〇lift_fresh_copy〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_device="cpu", tensor_shapes=[(2,3), (2,3)], error_types={torch.bool}) + # same dtype + [ErrorInvocation(TensorOfShape(2, 3, dtype=torch.int32, device="cpu"), TensorOfShape(2, 3, dtype=torch.float16, device="cpu"))] #different dtypes +) +def aten〇linalg_cross〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], dim: int = -1) -> int: + self_rank, self_dtype = self_rank_dtype + other_rank, other_dtype = other_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + assert self_dtype == other_dtype + assert self_dtype != torch.bool + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + @check_dtype_function( _check_two_tensor_op(dim=0, input_dtype=torch.float32) + _check_two_tensor_op(dim=0, input_dtype=torch.float64)) @@ -2032,11 +2475,21 @@ def aten〇max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: Lis self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7, 8)], kernel_size=[2, 2, 2])) +def aten〇max_pool3d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0, 0,), dilation: List[int] = (1, 1, 1,), ceil_mode: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) def aten〇max_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> Tuple[int, int]: self_rank, self_dtype = self_rank_dtype return self_dtype, torch.int64 +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[2, 2])) +def aten〇adaptive_max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + return self_dtype, torch.int64 + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇mish〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -2268,11 +2721,20 @@ def aten〇tril〡dtype(self_rank_dtype: Tuple[int, int], diagonal: int = 0) -> self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3)], dim1=0, dim2=1)) +def aten〇diagonal〡dtype(self_rank_dtype: Tuple[int, int], offset: int = 0, dim1: int = 0, dim2: int = 1) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇uniform〡dtype(self_rank_dtype: Tuple[int, int], from_: float = 0., to: float = 1., generator: Any = None) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype +def aten〇exponential〡dtype(self_rank_dtype: Tuple[int, int], lambd: float = 1., generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function([Invocation([1]), Invocation([1], dtype=torch.float16), Invocation([1], dtype=torch.complex64)]) @@ -2436,6 +2898,20 @@ def aten〇isnan〡dtype(self_rank_dtype: Tuple[int, int]) -> int: def aten〇isinf〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.bool +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex128, torch.complex64})) +def aten〇isneginf〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.complex128 and self_dtype != torch.complex64 + return torch.bool + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex128, torch.complex64})) +def aten〇isposinf〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.complex128 and self_dtype != torch.complex64 + return torch.bool + @check_dtype_function(_check_two_tensor_op()) def aten〇ne〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: return torch.bool @@ -2765,6 +3241,26 @@ def aten〇_convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_d dtypes = [input_dtype, weight_dtype] return promote_dtypes(ranks, dtypes) +def aten〇conv1d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), dilation: List[int] = (1,), groups: int = 1) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert input_dtype == weight_dtype + assert not is_complex_dtype(input_dtype) and input_dtype is not torch.bool + assert not is_complex_dtype(weight_dtype) and weight_dtype is not torch.bool + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + +def aten〇conv_tbc〡dtype(self_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Tuple[int, int], pad: int = 0) -> int: + self_rank, self_dtype = self_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert self_dtype == weight_dtype + assert not is_complex_dtype(self_dtype) and self_dtype is not torch.bool + assert not is_complex_dtype(weight_dtype) and weight_dtype is not torch.bool + ranks: List[Optional[int]] = [self_rank, weight_rank] + dtypes = [self_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + _convolution_deprecated_kwargs = { "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], "groups" : 1, "benchmark" : False, "deterministic" : False, "cudnn_enabled" : False} @@ -2811,6 +3307,10 @@ def aten〇conv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: input_rank, input_dtype = input_rank_dtype return input_dtype +def aten〇conv3d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1, 1,), padding: List[int] = (0, 0, 0,), dilation: List[int] = (1, 1, 1,), groups: int = 1) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + @check_dtype_function( _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1)]) + [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), @@ -2944,6 +3444,27 @@ def aten〇lerp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtyp dtypes = [self_dtype, end_dtype, weight_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1)], weight=0.5) + + # Different width + [Invocation(TensorOfShape(4, 3, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.float64), + weight=0.5), + # Different type + Invocation(TensorOfShape(4, 3, dtype=torch.int32), + TensorOfShape(4, 3, dtype=torch.float32), + weight=0.5), + Invocation(TensorOfShape(4, 3, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.float32), + weight=2)]) +def aten〇lerp〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtype: Tuple[int, int], weight: Union[int, float, complex]) -> int: + self_rank, self_dtype = self_rank_dtype + end_rank, end_dtype = end_rank_dtype + + ranks: List[Optional[int]] = [self_rank, end_rank, None] + dtypes = [self_dtype, end_dtype, get_dtype_of_scalar(weight)] + return promote_dtypes(ranks, dtypes) + @check_dtype_function( _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)], error_types={torch.bool}) + # Different width @@ -3040,6 +3561,14 @@ def aten〇fmod〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[ dtypes = [self_dtype, get_dtype_of_scalar(other)] return promote_dtypes(ranks, dtypes) +@check_dtype_function(_check_two_tensor_op()) +def aten〇fmod〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1.0)) @@ -3088,6 +3617,14 @@ def aten〇elu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, float assert not is_integer_dtype(self_dtype) return self_dtype +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64})) +def aten〇selu〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.bool + assert not is_integer_dtype(self_dtype) + return self_dtype + @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) @@ -3097,6 +3634,14 @@ def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: U dtypes = [self_dtype, get_dtype_of_scalar(other)] return promote_dtypes(ranks, dtypes) +@check_dtype_function(_check_two_tensor_op()) +def aten〇remainder〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + # TODO: This should be fixed by switching to FakeTensor instead of Meta tensor @check_dtype_function( _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1), (1, 1, 1), (1, 1, 1)], tensor_device="cpu", error_types={torch.bool}) + @@ -3159,6 +3704,12 @@ def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], sel dtypes = [get_dtype_of_scalar(self), other_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇nan_to_num〡dtype(self_rank_dtype: Tuple[int, int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function( [Invocation(TensorOfShape(2, 3, dtype=torch.float32), TensorOfShape(2, dtype=torch.int64), TensorOfShape(3, dtype=torch.float32), reduction=0, ignore_index=0), @@ -3351,6 +3902,13 @@ def aten〇any〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim return self_dtype return torch.bool +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) +def aten〇all〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.uint8: + return self_dtype + return torch.bool + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇min〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -3447,6 +4005,44 @@ def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Uni return dtype return aten〇std〡dtype(self_rank_dtype) +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.complex64, torch.complex128}, dtype=torch.float64) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16, torch.float16, torch.float32, torch.float64}, dtype=torch.complex128) + + [ErrorInvocation(NonZeroDTensorWithDtype(torch.float32), dtype=torch.int32)]) +def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[Union[int, float, complex]] = None, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + assert not is_integer_dtype(self_dtype) + if dtype is not None: + assert not is_integer_dtype(dtype) + if is_complex_dtype(self_dtype): + assert is_complex_dtype(dtype) + return aten〇std〡dtype((self_rank, dtype)) + assert not is_complex_dtype(dtype) + return dtype + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64})) +def aten〇norm〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], p: Union[int, float, complex] = 2) -> int: + self_rank, self_dtype = self_rank_dtype + assert not is_integer_dtype(self_dtype) + # The following check is added because aten〇std〡dtype + # does not handle complex32 transformation to float, + # so it is done manually (torch.half == torch.float16). + # Should possibly be added to aten〇std〡dtype. + if self_dtype == torch.complex32: + return torch.half + return aten〇std〡dtype(self_rank_dtype) + @check_dtype_function([Invocation(0.0), Invocation(0.0, dtype=torch.int32), Invocation(0.0, dtype=torch.float16), @@ -3692,6 +4288,26 @@ def aten〇randn〡dtype(size: List[int], dtype: Optional[int] = None, layout: O assert not is_integer_dtype(dtype) return dtype +@check_dtype_function([Invocation(start=1, end=10, steps=9), + Invocation(start=1, end=10, steps=9, dtype=torch.int32), + Invocation(start=1, end=10, steps=9, dtype=torch.double), + Invocation(start=1, end=10, steps=9, dtype=torch.complex64), + Invocation(start=1, end=10, steps=9, dtype=torch.complex128)]) +def aten〇linspace〡dtype(start: Union[int, float, complex], end: Union[int, float, complex], steps: int, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is None: + return torch.float32 + return dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64})) +def aten〇normal_functional〡dtype(self_rank_dtype: Tuple[int, int], mean: float = 0., std: float = 1., generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype is None: + return torch.float32 + assert not is_integer_dtype(self_dtype) + return self_dtype + @check_dtype_function([Invocation(size=[1], generator=None), Invocation(size=[1], generator=None, dtype=torch.float32), ErrorInvocation(size=[1], generator=None, dtype=torch.int32), @@ -3722,6 +4338,13 @@ def aten〇var_mean〡dtype(self_rank_dtype: Tuple[int, int], unbiased: bool = T return torch.float64, self_dtype return self_dtype, self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇tan〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.float32 + return self_dtype + @check_dtype_function(_check_two_tensor_op()) def aten〇atan2〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -3740,6 +4363,13 @@ def aten〇atan〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.float32 return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇atanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.float32 + return self_dtype + @check_dtype_function(_check_two_tensor_op()) def aten〇linear〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None) -> int: input_rank, input_dtype = input_rank_dtype @@ -3765,7 +4395,7 @@ def aten〇cat〡dtype(tensors_rank_dtype: List[Tuple[int, int]], dim: int = 0) return promote_dtypes(ranks, dtypes) @check_dtype_function( - [Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32), + [Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32), TensorOfShape(1, dtype=torch.int32)]),]) def aten〇einsum〡dtype(equation: str, tensors_rank_dtype: List[Tuple[int, int]], path: Optional[List[int]] = None) -> int: ranks: List[Optional[int]] = [] @@ -3777,6 +4407,13 @@ def aten〇einsum〡dtype(equation: str, tensors_rank_dtype: List[Tuple[int, int dtypes.append(tensor_dtype) return promote_dtypes(ranks, dtypes) +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3)])) +def aten〇trace〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.int64 + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇_shape_as_tensor〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.int64 @@ -3878,6 +4515,45 @@ def prims〇collapse〡dtype(a_rank_dtype: Tuple[int, int], start: int, end: int return a_dtype +def aten〇quantize_per_channel〡dtype(self_rank_dtype: Tuple[int, int], scales_rank_dtype: Tuple[int, int], zero_points_rank_dtype: Tuple[int, int], axis: int, dtype: int) -> int: + return dtype + +def aten〇quantize_per_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int, dtype: int) -> int: + return dtype + +def aten〇dequantize〇self〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + return torch.float32 + +def aten〇dequantize〇tensor〡dtype(qtensor_rank_dtype: Tuple[int, int]) -> int: + return torch.float32 + +def aten〇int_repr〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if (self_dtype == torch.quint8): + return torch.uint8 + if (self_dtype == torch.qint8): + return torch.int8 + return torch.int32 + +def aten〇_make_per_channel_quantized_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], axis: int) -> int: + self_rank, self_dtype = self_rank_dtype + if (self_dtype == torch.uint8): + return torch.quint8 + if (self_dtype == torch.int8): + return torch.qint8 + return torch.qint32 + +def aten〇_make_per_tensor_quantized_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int) -> int: + self_rank, self_dtype = self_rank_dtype + if (self_dtype == torch.uint8): + return torch.quint8 + if (self_dtype == torch.int8): + return torch.qint8 + return torch.qint32 + + + + # ============================================================================== # Main diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index fab101525bd3..a16279c9df78 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -255,15 +255,16 @@ def emit_with_mutating_variants(key, **kwargs): # Elementwise tensor compute ops for key in [ - "aten::tanh : (Tensor) -> (Tensor)", "aten::hardtanh : (Tensor, Scalar, Scalar) -> (Tensor)", "aten::elu : (Tensor, Scalar, Scalar, Scalar) -> (Tensor)", "aten::relu : (Tensor) -> (Tensor)", "aten::relu6 : (Tensor) -> (Tensor)", "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)", "aten::log : (Tensor) -> (Tensor)", + "aten::selu : (Tensor) -> (Tensor)", "aten::sigmoid : (Tensor) -> (Tensor)", "aten::sign : (Tensor) -> (Tensor)", + "aten::sinh : (Tensor) -> (Tensor)", "aten::sgn : (Tensor) -> (Tensor)", "aten::hardsigmoid : (Tensor) -> (Tensor)", "aten::hardswish : (Tensor) -> (Tensor)", @@ -271,15 +272,20 @@ def emit_with_mutating_variants(key, **kwargs): "aten::erfinv : (Tensor) -> (Tensor)", "aten::silu : (Tensor) -> (Tensor)", "aten::sin : (Tensor) -> (Tensor)", + "aten::asin : (Tensor) -> (Tensor)", + "aten::asinh : (Tensor) -> (Tensor)", "aten::exp : (Tensor) -> (Tensor)", "aten::expm1 : (Tensor) -> (Tensor)", "aten::cos : (Tensor) -> (Tensor)", + "aten::cosh : (Tensor) -> (Tensor)", "aten::acos : (Tensor) -> (Tensor)", + "aten::acosh : (Tensor) -> (Tensor)", + "aten::tan : (Tensor) -> (Tensor)", + "aten::tanh : (Tensor) -> (Tensor)", "aten::atan : (Tensor) -> (Tensor)", + "aten::atanh : (Tensor) -> (Tensor)", "aten::atan2 : (Tensor, Tensor) -> (Tensor)", - "aten::asin : (Tensor) -> (Tensor)", "aten::neg : (Tensor) -> (Tensor)", - "aten::ceil : (Tensor) -> (Tensor)", "aten::bitwise_not : (Tensor) -> (Tensor)", "aten::div.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::logical_or : (Tensor, Tensor) -> (Tensor)", @@ -287,19 +293,13 @@ def emit_with_mutating_variants(key, **kwargs): "aten::logical_xor : (Tensor, Tensor) -> (Tensor)", "aten::logical_not : (Tensor) -> (Tensor)", "aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", - "aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::lerp.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)", "aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::ge.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::lt.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::le.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::ne.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::div.Scalar : (Tensor, Scalar) -> (Tensor)", - "aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)", - "aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)", - "aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)", - "aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)", - "aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)", - "aten::le.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)", "aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)", @@ -312,6 +312,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::log10 : (Tensor) -> (Tensor)", "aten::sqrt : (Tensor) -> (Tensor)", "aten::log1p : (Tensor) -> (Tensor)", + "aten::logit : (Tensor, float?) -> (Tensor)", "aten::rsqrt : (Tensor) -> (Tensor)", "aten::abs : (Tensor) -> (Tensor)", "aten::reciprocal : (Tensor) -> (Tensor)", @@ -332,13 +333,22 @@ def emit_with_mutating_variants(key, **kwargs): # Elementwise tensor compute ops that don't have the standard mutating # variants. emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True) - emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) - emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) - emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True, has_folder=True) + emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True) + emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True) emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) - emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::le.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)") @@ -375,6 +385,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)") emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)") emit("aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)") + emit("aten::exponential : (Tensor, float, Generator?) -> (Tensor)") emit("aten::multinomial : (Tensor, int, bool, Generator?) -> (Tensor)") emit("aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::randint : (int, int[], int?, int?, Device?, bool?) -> (Tensor)") @@ -387,7 +398,6 @@ def emit_with_mutating_variants(key, **kwargs): emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)") emit_with_mutating_variants("aten::tril : (Tensor, int) -> (Tensor)") - emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants( "aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)") emit_with_mutating_variants( @@ -401,12 +411,20 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::matmul : (Tensor, Tensor) -> (Tensor)") emit("aten::mv : (Tensor, Tensor) -> (Tensor)") emit("aten::cosine_similarity : (Tensor, Tensor, int, float) -> (Tensor)") + emit( + "aten::conv3d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" + ) emit( "aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" ) + emit( + "aten::conv1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" + ) emit("aten::conv_transpose1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)") emit("aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)") emit("aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)") + emit("aten::conv_tbc : (Tensor, Tensor, Tensor, int) -> (Tensor)") + emit("aten::conv_tbc_backward : (Tensor, Tensor, Tensor, Tensor, int) -> (Tensor, Tensor, Tensor)") emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)") emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)") emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)") @@ -419,12 +437,19 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)" ) + emit( + "aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)" + ) emit( "aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)" ) + emit( + 'aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)' + ) emit( "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)" ) + emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True) emit( "aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)" ) @@ -486,6 +511,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)") emit("aten::_adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)") emit("aten::_adaptive_avg_pool3d_backward : (Tensor, Tensor) -> (Tensor)") + emit("aten::adaptive_max_pool2d : (Tensor, int[]) -> (Tensor, Tensor)") emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)") emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)") emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)") @@ -516,6 +542,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)") emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)") emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)") + emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)") emit("aten::linalg_qr : (Tensor, str) -> (Tensor, Tensor)") emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)") emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)") @@ -534,6 +561,9 @@ def emit_with_mutating_variants(key, **kwargs): # Misc tensor ops. emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)") + emit("aten::replication_pad2d : (Tensor, int[]) -> (Tensor)") + emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)") + emit("aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)") emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)") emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True) emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True) @@ -543,19 +573,21 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True) emit("aten::Bool.Tensor : (Tensor) -> (bool)") emit("aten::is_floating_point : (Tensor) -> (bool)", has_folder=True) - emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)") + emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True) emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)") - emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)") + emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True) emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)") - emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)") + emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)", has_folder=True) emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)") emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)") emit("aten::scalar_tensor : (Scalar, int?, int?, Device?, bool?) -> (Tensor)") - emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)") + emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)", has_folder=True) emit("aten::isnan : (Tensor) -> (Tensor)") emit("aten::isinf : (Tensor) -> (Tensor)") + emit("aten::isneginf : (Tensor) -> (Tensor)") + emit("aten::isposinf : (Tensor) -> (Tensor)") emit("aten::all : (Tensor) -> (Tensor)") emit("aten::all.bool : (bool[]) -> (bool)") emit("aten::all.dim : (Tensor, int, bool) -> (Tensor)") @@ -569,8 +601,9 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::argmin : (Tensor, int?, bool) -> (Tensor)") emit("aten::one_hot : (Tensor, int) -> (Tensor)") emit("aten::einsum : (str, Tensor[], int[]?) -> (Tensor)") + emit("aten::trace : (Tensor) -> (Tensor)") emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)") - emit("aten::clone : (Tensor, int?) -> (Tensor)") + emit("aten::clone : (Tensor, int?) -> (Tensor)", has_folder=True) emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)") emit("aten::contiguous : (Tensor, int) -> (Tensor)") emit_with_mutating_variants("aten::copy : (Tensor, Tensor, bool) -> (Tensor)") @@ -593,10 +626,10 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_canonicalizer=True, has_folder=True) emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)") emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)") - emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)") + emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::_index_put_impl : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)") emit("aten::_index_put_impl_.hacked_twin : (Tensor, Tensor[], Tensor, bool, bool) -> (Tensor)") - emit("aten::item : (Tensor) -> (Scalar)") + emit("aten::item : (Tensor) -> (Scalar)", has_folder=True) emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)") emit("aten::numel : (Tensor) -> (int)", has_canonicalizer=True) emit("aten::repeat : (Tensor, int[]) -> (Tensor)") @@ -607,7 +640,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)") emit("aten::resize : (Tensor, int[], int?) -> (Tensor)") emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)") - emit("aten::select.int : (Tensor, int, int) -> (Tensor)") + emit("aten::select.int : (Tensor, int, int) -> (Tensor)", has_folder=1) emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True) emit("aten::sum : (Tensor, int?) -> (Tensor)") emit("aten::sum.dim_IntList : (Tensor, int[]?, bool, int?) -> (Tensor)") @@ -628,18 +661,19 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::type_as : (Tensor, Tensor) -> (Tensor)") emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True) emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)") - emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)") - emit("aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)") - emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)") - emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)") + emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)", has_folder=True) + emit("aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_folder=True) + emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)", has_folder=True) + emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)", has_folder=True) + emit("aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)") emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)", has_folder=True) emit("aten::len.Tensor : (Tensor) -> (int)") emit("aten::cpu : (Tensor) -> (Tensor)") emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)") emit_with_mutating_variants("aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)") emit_with_mutating_variants("aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)") - emit("aten::IntImplicit : (Tensor) -> (int)") - emit("aten::FloatImplicit : (Tensor) -> (float)") + emit("aten::IntImplicit : (Tensor) -> (int)", has_canonicalizer=True) + emit("aten::FloatImplicit : (Tensor) -> (float)", has_canonicalizer=True) emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)") emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True) emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True) @@ -647,7 +681,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)") emit("aten::t : (Tensor) -> (Tensor)") emit("aten::numpy_T : (Tensor) -> (Tensor)") - emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)") + emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)", has_folder=True) emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)") emit("aten::new_full : (Tensor, int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)") emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)") @@ -655,11 +689,13 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)") emit("aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)") + emit("aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)", has_verifier=True) # Functionalization ops emit("aten::alias_copy : (Tensor) -> (Tensor)") emit("aten::alias : (Tensor) -> (Tensor)", has_folder=True) emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)") + emit("aten::diagonal : (Tensor, int, int, int) -> (Tensor)") emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)") emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)") emit("aten::permute_copy : (Tensor, int[]) -> (Tensor)") @@ -683,6 +719,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)") emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)") emit("aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)") + emit("aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)") # Dict ops. emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True) @@ -694,7 +731,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::Delete.Dict_str : (Dict(str, t), str) -> ()") # List ops. - emit("aten::cat : (Tensor[], int) -> (Tensor)", has_folder=True) + emit("aten::cat : (Tensor[], int) -> (Tensor)", has_canonicalizer=True, has_folder=True) emit("aten::stack : (Tensor[], int) -> (Tensor)") emit("aten::append.t : (t[], t) -> (t[])") emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True) @@ -705,9 +742,10 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::ne.int_list : (int[], int[]) -> (bool)") emit("aten::any.bool : (bool[]) -> (bool)", has_folder=True) emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True) - emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)") + emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)", has_folder=True) emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])") emit("aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])") + emit("aten::split.sizes : (Tensor, int[], int) -> (Tensor[])") emit("aten::unbind.int : (Tensor, int) -> (Tensor[])") emit("aten::chunk : (Tensor, int, int) -> (Tensor[])") @@ -718,6 +756,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::str : (t) -> (str)") emit("aten::format : (...) -> (str)") emit("aten::join : (str, str[]) -> (str)") + emit("aten::warn : (str, int) -> ()") # Type conversion ops. emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True) @@ -769,6 +808,7 @@ def emit_with_mutating_variants(key, **kwargs): has_canonicalizer=True) emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True) emit("aten::_set_item.t : (t[], int, t) -> (t[])") + emit("aten::mul : (Scalar, Scalar) -> (Scalar)", has_folder=True) emit("aten::div : (Scalar, Scalar) -> (float)", has_folder=True) emit("aten::add : (Scalar, Scalar) -> (Scalar)", has_folder=True) emit("aten::sub : (Scalar, Scalar) -> (Scalar)", has_folder=True) @@ -800,6 +840,15 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)") emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)") + # quantized ops + emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)") + emit("aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)") + emit("aten::dequantize.self : (Tensor) -> (Tensor)") + emit("aten::dequantize.tensor : (Tensor) -> (Tensor)") + emit("aten::int_repr : (Tensor) -> (Tensor)") + emit("aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)") + emit("aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)") + # ========================================================================== # `prim::` namespace. # ========================================================================== @@ -809,7 +858,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("prim::device : (Tensor) -> (Device)", has_canonicalizer=True) emit("prim::dtype : (Tensor) -> (int)", has_folder=True) emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True) - emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)") + emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)", has_folder=True) emit("prim::min.self_int : (int[]) -> (int)", has_folder=True) emit("prim::min.int : (int, int) -> (int)", has_folder=True) emit("prim::max.self_int : (int[]) -> (int)") diff --git a/projects/pt1/python/torch_mlir/__init__.py b/projects/pt1/python/torch_mlir/torchscript.py similarity index 96% rename from projects/pt1/python/torch_mlir/__init__.py rename to projects/pt1/python/torch_mlir/torchscript.py index f5a4f4fdf992..33eddb6b1dd8 100644 --- a/projects/pt1/python/torch_mlir/__init__.py +++ b/projects/pt1/python/torch_mlir/torchscript.py @@ -26,7 +26,7 @@ from .compiler_utils import prepare_model, map_kwargs_into_args class OutputType(Enum): - """The kind of output that `torch_mlir.compile` can produce. + """The kind of output that `torchscript.compile` can produce. In MLIR terminology, this describes the mix of dialects that will be produced by the conversion process. @@ -252,7 +252,7 @@ def _get_for_tracing( # compiler where each backend can "own" its set of legal ops. BACKEND_LEGAL_OPS = { OutputType.TOSA: ['aten.flatten.using_ints', 'aten.native_layer_norm', 'aten.linear'], - OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints', ], + OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints','aten.adaptive_avg_pool1d', 'aten.unflatten.int'], OutputType.STABLEHLO: [], } @@ -323,7 +323,8 @@ def compile(model: torch.nn.Module, backend_legal_ops: Optional[Sequence[str]] = None, extra_library: Iterable[Callable] = [], verbose: bool = False, - use_make_fx: bool = False): + use_make_fx: bool = False, + enable_ir_printing: bool = False): """Convert a PyTorch model to MLIR. Args: @@ -352,7 +353,13 @@ def compile(model: torch.nn.Module, into the abstract interpretation library. See `docs/adding_abstract_interpretation_functions.md` for more info on the format the functions should have. - verbose: If true, print extra information about the conversion. + verbose: If true, print extra information about the conversion to + stdout. + enable_ir_printing: If true, print the IR before and after each pass to + stderr. This is equivalent to setting MLIR's `-print-ir-after-all` + flag. Note that this can easily generate many gigabytes of text, + so make sure to pipe stderr to a file (for example, run + `python tinymodel.py 2> tinymodel.stderr` on Linux). Returns: An MLIR module that contains the converted model in the specified @@ -389,13 +396,13 @@ def compile(model: torch.nn.Module, strip_overloads(model) # Get the model as JIT IR (TorchScript) for import. - # TODO: Longer-term, we probably need to split `torch_mlir.compile`. + # TODO: Longer-term, we probably need to split `torchscript.compile`. # There should be an "acquisition" step that does # tracing/scripting/importing from FX/using torchdynamo.export/etc. # + any lowering to the backend contract. Then there should be a # "backend lowering" step that does the actual lowering to each # backend. This separation should be visible at the Python API level, and - # we can implement a deliberately simplified API like `torch_mlir.compile` + # we can implement a deliberately simplified API like `torchscript.compile` # on top of those building blocks. if isinstance(model, torch.jit.ScriptModule): # If the user already converted the model to JIT IR themselves, just @@ -456,6 +463,7 @@ def compile(model: torch.nn.Module, mb.module, f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})", "Lowering TorchScript IR -> Torch Backend IR", + enable_ir_printing=enable_ir_printing, ) return _lower_mlir_module(verbose, output_type, mb.module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/configs/__init__.py index 4ca4c3dce803..b11c242db2cb 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/__init__.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/__init__.py @@ -6,6 +6,7 @@ from .lazy_tensor_core import LazyTensorCoreTestConfig from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig from .native_torch import NativeTorchTestConfig +from .onnx_backend import OnnxBackendTestConfig from .torchscript import TorchScriptTestConfig from .stablehlo_backend import StablehloBackendTestConfig from .tosa_backend import TosaBackendTestConfig diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py index 6ad41dd6dccb..8c99278b0ec3 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py @@ -6,7 +6,7 @@ from typing import Any import torch -import torch_mlir +from torch_mlir import torchscript from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem @@ -30,7 +30,7 @@ def __init__(self, backend: LinalgOnTensorsBackend): def compile(self, program: torch.nn.Module) -> Any: example_args = convert_annotations_to_placeholders(program.forward) - module = torch_mlir.compile( + module = torchscript.compile( program, example_args, output_type="linalg-on-tensors") return self.backend.compile(module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py new file mode 100644 index 000000000000..e411a7cbb67f --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py @@ -0,0 +1,101 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +from pathlib import Path +from typing import Any + +import io +import onnx +import torch +import torch_mlir + +from torch_mlir_e2e_test.onnx_backends.abc import OnnxBackend +from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem +from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders +from .utils import ( + recursively_convert_to_numpy, + recursively_convert_from_numpy, +) + +from torch_mlir.extras import onnx_importer +from torch_mlir.dialects import torch as torch_d +from torch_mlir.ir import Context, Module + + +def import_onnx(contents): + # Import the ONNX model proto from the file contents: + raw_model = onnx.load_from_string(contents) + model_proto = onnx.shape_inference.infer_shapes(raw_model) + + # Import the ONNX module into an MLIR module: + context = Context() + torch_d.register_dialect(context) + model_info = onnx_importer.ModelInfo(model_proto) + m = model_info.create_module(context=context) + imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m.operation) + imp.import_all() + return m + + +def convert_onnx(model, inputs): + buffer = io.BytesIO() + + # Process the type information so we export with the dynamic shape information + examples = [] + input_names = [] + dynamic_tensors = {} + for (index, arg) in enumerate(inputs): + shape = map(lambda d : d if d >= 0 else 1, arg.shape) + shape = tuple(shape) + examples.append(torch.zeros(size=shape, dtype=arg.dtype)) + + input_name = "input_{}".format(index) + input_names.append(input_name) + + dynamic_dims = {} + for (dimindex, dim) in enumerate(arg.shape): + if (dim < 0): + dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex) + + if (dynamic_dims): + dynamic_tensors[input_name] = dynamic_dims + + + examples=tuple(examples) + torch.onnx.export(model, examples, buffer, input_names=input_names, dynamic_axes=dynamic_tensors) + buffer = buffer.getvalue() + return import_onnx(buffer) + +class OnnxBackendTestConfig(TestConfig): + """Base class for TestConfig's that are implemented with ONNX. + + This class handles all the common lowering that torch-mlir does before + reaching the ONNX abstraction level. + """ + def __init__(self, backend: OnnxBackend, use_make_fx: bool = False): + super().__init__() + self.backend = backend + self.use_make_fx = use_make_fx + + def compile(self, program: torch.nn.Module) -> Any: + example_args = convert_annotations_to_placeholders(program.forward) + onnx_module = convert_onnx(program, example_args) + compiled_module = self.backend.compile(onnx_module) + return compiled_module + + + + def run(self, artifact: Any, trace: Trace) -> Trace: + backend_module = self.backend.load(artifact) + result: Trace = [] + for item in trace: + numpy_inputs = recursively_convert_to_numpy(item.inputs) + outputs = getattr(backend_module, "main_graph")(*numpy_inputs) + output = recursively_convert_from_numpy(outputs) + result.append( + TraceItem(symbol=item.symbol, + inputs=item.inputs, + output=output)) + return result diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py index 45f32bb0b3fe..1ab8a8d22b4f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py @@ -6,7 +6,7 @@ from typing import Any import torch -import torch_mlir +from torch_mlir import torchscript from torch_mlir_e2e_test.stablehlo_backends.abc import StablehloBackend from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem @@ -18,10 +18,10 @@ class StablehloBackendTestConfig(TestConfig): - """Base class for TestConfig's that are implemented with linalg-on-tensors. + """Base class for TestConfig's that are implemented with StableHLO. This class handles all the common lowering that torch-mlir does before - reaching the linalg-on-tensors abstraction level. + reaching the StableHLO abstraction level. """ def __init__(self, backend: StablehloBackend): @@ -30,7 +30,7 @@ def __init__(self, backend: StablehloBackend): def compile(self, program: torch.nn.Module) -> Any: example_args = convert_annotations_to_placeholders(program.forward) - module = torch_mlir.compile(program, example_args, output_type="stablehlo") + module = torchscript.compile(program, example_args, output_type="stablehlo") return self.backend.compile(module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py index c53227acf36a..bdc410741cae 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py @@ -17,7 +17,7 @@ from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func from torch_mlir.dynamo import _get_decomposition_table -from torch_mlir import ( +from torch_mlir.torchscript import ( _example_args, OutputType, BACKEND_LEGAL_OPS, @@ -53,6 +53,40 @@ def _returns_empty_tuple(fx_graph: torch.fx.GraphModule) -> bool: return False return True +# Replaces torch.aten.add.Tensor/torch.aten.mul.Tensor to +# torch.aten.add.Scalar/torch.aten.mul.Scalar in case of Scalar argument +# Cannot be done on earlier stage, e.g. in _FXGraphImporter as it +# needs to check argument types, which are not yet determined. +# Maybe schema or target should be changed, but it decided in +# _dynamo eval_frame on pytorch side. Also Python schema not matches +# with mlir Schema - check include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +# So in general it covers some of overload cases, which done on Python side automatically. +# e.g. conversion Scalar -> Tensor and vice versa +def scalarize_tensor_ops_on_scalars(gm: torch.fx.GraphModule): + # Modify gm.graph + for node in gm.graph.nodes: + # Checks if we're calling a function (i.e: + # torch.add) + if node.op == 'call_function': + # The target attribute is the function + # that call_function calls. + # call_function[target=torch.ops.aten.add.Tensor](args = (%arg64_1, 1), kwargs = {}) + if node.target == torch.ops.aten.add.Tensor: + if len(node.args) != 2 or node.kwargs != {}: + continue + elif not isinstance(node.args[1], torch.fx.node.Node): + node.target = torch.ops.aten.add.Scalar + if node.target == torch.ops.aten.mul.Tensor: + if len(node.args) != 2 or node.kwargs != {}: + continue + elif not isinstance(node.args[1], torch.fx.node.Node): + node.target = torch.ops.aten.mul.Scalar + + gm.graph.lint() # Does some checks to make sure the + + # Recompile the forward() method of `gm` from its Graph + gm.recompile() + def jit( model: torch.nn.Module, @@ -87,6 +121,8 @@ def my_aot_autograd_backend(gm: torch.fx.GraphModule, # way of differentiating between the two. assert not _returns_empty_tuple(gm), "encountered graph that does not return anything" + scalarize_tensor_ops_on_scalars(gm) + nonlocal mlir_module *_, model_name, nth_graph = get_aot_compilation_context() mlir_module = import_fx_graph_as_func(gm.graph, model_name) diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py index 89b90567b1d4..8aa2d0e63eb6 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py @@ -6,7 +6,7 @@ from typing import Any import torch -import torch_mlir +from torch_mlir import torchscript from torch_mlir_e2e_test.tosa_backends.abc import TosaBackend from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem @@ -18,10 +18,10 @@ class TosaBackendTestConfig(TestConfig): - """Base class for TestConfig's that are implemented with linalg-on-tensors. + """Base class for TestConfig's that are implemented with TOSA. This class handles all the common lowering that torch-mlir does before - reaching the linalg-on-tensors abstraction level. + reaching the TOSA abstraction level. """ def __init__(self, backend: TosaBackend, use_make_fx: bool = False): super().__init__() @@ -30,7 +30,7 @@ def __init__(self, backend: TosaBackend, use_make_fx: bool = False): def compile(self, program: torch.nn.Module) -> Any: example_args = convert_annotations_to_placeholders(program.forward) - module = torch_mlir.compile( + module = torchscript.compile( program, example_args, output_type="tosa", use_make_fx=self.use_make_fx) return self.backend.compile(module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/framework.py b/projects/pt1/python/torch_mlir_e2e_test/framework.py index f1fbad2ec914..d3fecf54d99c 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/framework.py +++ b/projects/pt1/python/torch_mlir_e2e_test/framework.py @@ -24,11 +24,19 @@ from typing import Any, Callable, List, NamedTuple, Optional, TypeVar, Union, Dict from itertools import repeat +import os import sys import traceback -import torch import multiprocess as mp +from multiprocess import set_start_method +try: + set_start_method("spawn") +except RuntimeError: + # Children can error here so we suppress. + pass + +import torch TorchScriptValue = Union[int, float, List['TorchScriptValue'], Dict['TorchScriptValue', @@ -317,7 +325,15 @@ def compile_and_run_test(test: Test, config: TestConfig, verbose=False) -> Any: def run_tests(tests: List[Test], config: TestConfig, sequential=False, verbose=False) -> List[TestResult]: """Invoke the given `Test`'s with the provided `TestConfig`.""" - num_processes = min(int(mp.cpu_count() * 1.1), len(tests)) + num_processes = min(int(mp.cpu_count() * 0.8) + 1, len(tests)) + try: + env_concurrency = int(os.getenv("TORCH_MLIR_TEST_CONCURRENCY", "0")) + except ValueError as e: + raise ValueError("Bad value for TORCH_MLIR_TEST_CONCURRENCY env var: " + "Expected integer.") from e + if env_concurrency > 0: + num_processes = min(num_processes, env_concurrency) + # TODO: We've noticed that on certain 2 core machine parallelizing the tests # makes the llvm backend legacy pass manager 20x slower than using a # single process. Need to investigate the root cause eventually. This is a @@ -344,7 +360,7 @@ def run_tests(tests: List[Test], config: TestConfig, sequential=False, verbose=F pool = mp.Pool(num_processes) arg_list = zip(tests, repeat(config)) handles = pool.starmap_async(compile_and_run_test, arg_list) - results = handles.get() + results = handles.get(timeout=360) tests_with_results = {result.unique_name for result in results} all_tests = {test.unique_name for test in tests} diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index 1b9dbb0d2c51..ad2669c51d50 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -123,6 +123,7 @@ def invoke(*args): LOWERING_PIPELINE = "builtin.module(" + ",".join([ "func.func(refback-generalize-tensor-pad)", + "func.func(refback-generalize-tensor-concat)", # Apply some optimizations. It would be great if MLIR had more useful # optimizations that worked out of the box here. # Note: When measured, this doesn't seem to actually help that much @@ -130,8 +131,15 @@ def invoke(*args): # This is likely because if things are naturally fusable we usually already # emit things in that form from the high level (e.g. single linalg-generic). # Other backends are likely to benefit more. + "func.func(linalg-generalize-named-ops)", "func.func(linalg-fuse-elementwise-ops)", "convert-shape-to-std", + # MLIR Sparsifier mini-pipeline. Note that this is the bare minimum + # to ensure operations on sparse tensors are lowered to loops. + "sparse-assembler", + "sparsification-and-bufferization", + "sparse-storage-specifier-to-llvm", + "inline", # inline sparse helper methods where useful # Bufferize. "func.func(scf-bufferize)", "func.func(tm-tensor-bufferize)", @@ -195,10 +203,11 @@ def compile(self, imported_module: Module): An opaque, backend specific compiled artifact object that can be passed to `load`. """ - run_pipeline_with_repro_report( imported_module, LOWERING_PIPELINE, - "Lowering Linalg-on-Tensors IR to LLVM with RefBackend") + "Lowering Linalg-on-Tensors IR to LLVM with RefBackend", + enable_ir_printing=False, + ) return imported_module def load(self, module) -> RefBackendInvoker: diff --git a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/abc.py b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/abc.py new file mode 100644 index 000000000000..684c08df4fa1 --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/abc.py @@ -0,0 +1,49 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import abc +from typing import TypeVar + +import torch + +from torch_mlir.ir import Module + +# A type shared between the result of `OnnxBackend.compile` and the +# input to `OnnxBackend.load`. Each backend will likely have a +# different definition of this type. +CompiledArtifact = TypeVar('CompiledArtifact') + +# A wrapper around a backend-specific loaded program representation +# that uniformly translates the `x.method(...)` interface expected of +# Torch modules into appropriate lower-level operations. +Invoker = TypeVar('Invoker') + + +class OnnxBackend(abc.ABC): + """The interface to an ONNX backend. + + Backends are recommended to raise meaningful exceptions in case of error, + ideally with easy reproduction instructions. + """ + @abc.abstractmethod + def compile(self, module: Module) -> CompiledArtifact: + """Compile the provided MLIR module into a compiled artifact. + + The module adheres to the ONNX backend contract + (see the VerifyOnnxBackendContract pass). + + The compiled artifact can be any type, but must be correctly + interpreted by the `load` method. + """ + + @abc.abstractmethod + def load(self, artifact: CompiledArtifact) -> Invoker: + """Load the compiled artifact into a uniformly invokable form. + + The compiled artifact is the result of a previous call to `compile`. + + See the description of `Invoker` for the requirements on the returned + type. + """ diff --git a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py new file mode 100644 index 000000000000..449e6bb40f01 --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py @@ -0,0 +1,67 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + + +from torch_mlir.compiler_utils import run_pipeline_with_repro_report +from torch_mlir.ir import * +from torch_mlir.passmanager import * +from torch_mlir.torchscript import OutputType +from torch_mlir.torchscript import _lower_mlir_module + +from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend + +from .abc import OnnxBackend + +__all__ = [ + "LinalgOnTensorsOnnxBackend", +] + +# The pipeline of func.func passes that lower the ONNX backend contract to the +# Linalg-on-Tensors backend contract accepted by RefBackend. +ONNX_TO_TORCH_FUNC_PIPELINE = ",".join([ + "convert-torch-onnx-to-torch", +]) + + +class LinalgOnTensorsOnnxBackend(OnnxBackend): + """Main entry-point for the linalg-on-tensors based ONNX backend. + + This currently uses the linalg-on-tensors RefBackend for actual execution. + """ + + def __init__(self): + super().__init__() + self.refbackend = RefBackendLinalgOnTensorsBackend() + + def compile(self, imported_module: Module): + """Compiles an imported module that satisfied the ONNX backend contract. + + Args: + imported_module: The MLIR module consisting of ONNX operations wrapped by + torch.operator. + Returns: + An opaque, backend specific compiled artifact object that can be + passed to `load`. + """ + run_pipeline_with_repro_report( + imported_module, + f"builtin.module(func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))", + "Lowering Onnx backend contract to Linalg-on-Tensors backend contract") + + backend_legal_ops = ['aten.flatten.using_ints','aten.adaptive_avg_pool1d', 'aten.unflatten.int'] + option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}" + run_pipeline_with_repro_report( + imported_module, + f"builtin.module(torch-lower-to-backend-contract{option_string})", + "Lowering TorchFX IR -> Torch Backend IR", + ) + + imported_module = _lower_mlir_module(False, OutputType.LINALG_ON_TENSORS, imported_module) + compiled_module = self.refbackend.compile(imported_module) + return compiled_module + + def load(self, module): + """Loads a compiled artifact into the runtime.""" + return self.refbackend.load(module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py new file mode 100644 index 000000000000..7dee2041c724 --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py @@ -0,0 +1,56 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +from torch_mlir.ir import * +from torch_mlir.passmanager import * +from torch_mlir.compiler_utils import run_pipeline_with_repro_report + +from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend + +from .abc import StablehloBackend + +__all__ = [ + "LinalgOnTensorsStablehloBackend", +] + +# The pipeline of func.func passes that lower the STABLEHLO backend contract to the +# Linalg-on-Tensors backend contract accepted by RefBackend. +STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join([ + "func.func(chlo-legalize-to-stablehlo)", + "canonicalize", + "stablehlo-legalize-to-linalg" +]) + + +class LinalgOnTensorsStablehloBackend(StablehloBackend): + """Main entry-point for the linalg-on-tensors based Stablehlo backend. + + This currently uses the linalg-on-tensors RefBackend for actual execution. + """ + + def __init__(self): + super().__init__() + self.refbackend = RefBackendLinalgOnTensorsBackend() + + def compile(self, imported_module: Module): + """Compiles an imported module that satisfied the Stablehlo backend contract. + + Args: + imported_module: The MLIR module consisting of funcs in the Stablehlo dialect. + Returns: + An opaque, backend specific compiled artifact object that can be + passed to `load`. + """ + + run_pipeline_with_repro_report( + imported_module, + f"builtin.module({STABLEHLO_TO_LINALG_FUNC_PIPELINE})", + "Lowering STABLEHLO backend contract to Linalg-on-Tensors backend contract") + + return self.refbackend.compile(imported_module) + + def load(self, module): + """Loads a compiled artifact into the runtime.""" + return self.refbackend.load(module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py index 2d7147955053..0d16158af887 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -10,13 +10,11 @@ from torch_mlir._version import torch_version_for_comparison, version COMMON_TORCH_MLIR_LOWERING_XFAILS = { - "NativeGroupNormModule_basic", "NativeGroupNormBackwardModule_basic", "QuantizedMLP_basic", "ReduceMaxAlongDimUnsignedInt_basic", "RepeatInterleaveModule_basic", "Im2ColModule_basic", - "ElementwiseClampIntModule_basic", "ReduceMinAlongDimUnsignedInt_basic", "ElementwiseToDtypeI64ToUI8Module_basic", } @@ -63,3 +61,6 @@ def register_all_tests(): from . import return_types from . import control_flow from . import stats + from . import padding + from . import diagonal + from . import gridsampler diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py index 8237d2601711..fff3e60c4605 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py @@ -62,6 +62,7 @@ def forward(self): def ArangeZeroElementOutputModule_basic(module, tu: TestUtils): module.forward() +# ============================================================================== class ArangeStartIntModule(torch.nn.Module): def __init__(self): @@ -130,6 +131,7 @@ def forward(self): def ArangeNegativeStartFloatModule_basic(module, tu: TestUtils): module.forward() +# ============================================================================== class ArangeStartStepIntModule(torch.nn.Module): def __init__(self): @@ -198,6 +200,7 @@ def forward(self): def ArangeStartNegativeStepFloatModule_basic(module, tu: TestUtils): module.forward() +# ============================================================================== class ArangeDtypeFloatModule(torch.nn.Module): def __init__(self): @@ -232,6 +235,7 @@ def forward(self): def ArangeDtypeIntModule_basic(module, tu: TestUtils): module.forward() +# ============================================================================== class ArangeFalsePinMemoryModule(torch.nn.Module): def __init__(self): @@ -254,7 +258,7 @@ def ArangeFalsePinMemoryModule_basic(module, tu: TestUtils): class ArangeStartOutModule(torch.nn.Module): def __init__(self): super().__init__() - + @export @annotate_args([ None, @@ -270,7 +274,7 @@ def ArangeStartOutModule_basic(module, tu: TestUtils): class ArangeStartOutViewModule(torch.nn.Module): def __init__(self): super().__init__() - + @export @annotate_args([ None, @@ -286,7 +290,7 @@ def ArangeStartOutViewModule_basic(module, tu: TestUtils): class ArangeStartOutDtypeModule(torch.nn.Module): def __init__(self): super().__init__() - + @export @annotate_args([ None, @@ -298,3 +302,81 @@ def forward(self, x): @register_test_case(module_factory=lambda: ArangeStartOutDtypeModule()) def ArangeStartOutDtypeModule_basic(module, tu: TestUtils): module.forward(torch.zeros(12).to(torch.int64)) + +# ============================================================================== + +class LinspaceModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + return torch.linspace(-10.1, 10.1, 10) + +@register_test_case(module_factory=lambda: LinspaceModule()) +def LinspaceModule_basic(module, tu: TestUtils): + module.forward() + +class LinspaceDtypeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + return torch.linspace(-10.1, 10.1, 10, dtype=torch.int64) + + +@register_test_case(module_factory=lambda: LinspaceDtypeModule()) +def LinspaceDtypeModule_basic(module, tu: TestUtils): + module.forward() + +class LinspaceEmptyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + return torch.linspace(-10.1, 10.1, 0) + +@register_test_case(module_factory=lambda: LinspaceEmptyModule()) +def LinspaceEmptyModule_basic(module, tu: TestUtils): + module.forward() + +class LinspaceOneSizeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + return torch.linspace(-10.1, 10.1, 1) + +@register_test_case(module_factory=lambda: LinspaceOneSizeModule()) +def LinspaceOneSizeModule_basic(module, tu: TestUtils): + module.forward() + +class LinspaceTwoSizeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + return torch.linspace(-10.1, 10.1, 2) + +@register_test_case(module_factory=lambda: LinspaceTwoSizeModule()) +def LinspaceTwoSizeModule_basic(module, tu: TestUtils): + module.forward() diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 0d371fe37008..e5fab589258f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -552,8 +552,179 @@ def ConstantPadNdPartialStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class ReflectionPad1dModule3dInput(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 2, 4], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad1d(x, (3,1)) + +class ReplicationPad2dModule_basic_module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 3, 3], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.replication_pad2d(x, (1, 2, 3, 4)) + + +@register_test_case(module_factory=lambda: ReplicationPad2dModule_basic_module()) +def ReplicationPad2dModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, 3, low=-1)) + +# ============================================================================== + +class ReplicationPad2dModule_left0_module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 3, 3], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.replication_pad2d(x, (0, 2, 3, 4)) + + +@register_test_case(module_factory=lambda: ReplicationPad2dModule_left0_module()) +def ReplicationPad2dModule_left0(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, 3, low=-1)) + +# ============================================================================== + +class ReplicationPad2dModule_right0_module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 3, 3], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.replication_pad2d(x, (1, 0, 3, 4)) + + +@register_test_case(module_factory=lambda: ReplicationPad2dModule_right0_module()) +def ReplicationPad2dModule_right0(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, 3, low=-1)) + +# ============================================================================== + +class ReplicationPad2dModule_top0_module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 3, 3], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.replication_pad2d(x, (1, 2, 0, 4)) + + +@register_test_case(module_factory=lambda: ReplicationPad2dModule_top0_module()) +def ReplicationPad2dModule_top0(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, 3, low=-1)) + +# ============================================================================== + +class ReplicationPad2dModule_bottom0_module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 3, 3], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.replication_pad2d(x, (1, 2, 3, 0)) + + +@register_test_case(module_factory=lambda: ReplicationPad2dModule_bottom0_module()) +def ReplicationPad2dModule_bottom0(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, 3, low=-1)) + +# ============================================================================== + +@register_test_case(module_factory=lambda: ReflectionPad1dModule3dInput()) +def ReflectionPad1dModule3dInput_basic(module, tu: TestUtils): + module.forward(tu.rand(1,2,4)) +class ReflectionPad1dModule2dInput(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 4], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad1d(x, (3,2)) + + +@register_test_case(module_factory=lambda: ReflectionPad1dModule2dInput()) +def ReflectionPad1dModule2dInput_basic(module, tu: TestUtils): + module.forward(tu.rand(2,4)) + +class ReflectionPad1dModule3dInputLeft(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 4, 5], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad1d(x, (2,0)) + + +@register_test_case(module_factory=lambda: ReflectionPad1dModule3dInputLeft()) +def ReflectionPad1dModule3dInput_Left(module, tu: TestUtils): + module.forward(tu.rand(1,4,5)) + +class ReflectionPad1dModule2dInputRight(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 6], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad1d(x, (0,3)) + + +@register_test_case(module_factory=lambda: ReflectionPad1dModule2dInputRight()) +def ReflectionPad1dModule2dInput_Right(module, tu: TestUtils): + module.forward(tu.rand(3,6)) + +# ============================================================================== class TransposeIntModule(torch.nn.Module): def __init__(self): @@ -3739,6 +3910,50 @@ def ScalarImplicitIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(low=-100, high=100)) +# ============================================================================== + + +class FloatImplicitModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.float64, True), + ]) + def forward(self, x): + return float(torch.ops.aten.FloatImplicit(x)) + + +@register_test_case(module_factory=lambda: FloatImplicitModule()) +def FloatImplicitModule_basic(module, tu: TestUtils): + module.forward(tu.rand().double()) + + +# ============================================================================== + + +class IntImplicitModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.int64, True), + ]) + def forward(self, x): + return float(torch.ops.aten.IntImplicit(x)) + + +@register_test_case(module_factory=lambda: IntImplicitModule()) +def IntImplicitModule_basic(module, tu: TestUtils): + module.forward(tu.randint()) + + # ============================================================================== class PowIntFloat(torch.nn.Module): @@ -4112,7 +4327,13 @@ def __init__(self): ([-1, -1, -1], torch.float32, True), ]) def forward(self, val): - return torch.ops.aten.cumsum(val, 1) + # the onnx cumsum op uses a constant 1d tensor + # to specify the dimension along which to do cumsum + # we replicate that here to ensure that cumsum correctly + # trigger the relevant folders and provides TMTensor + # with a constant dimension + ones = torch.ones([1], dtype=torch.int32) + return torch.ops.aten.cumsum(val, ones.item()) @register_test_case(module_factory=lambda: CumsumModule()) def CumsumModule_basic(module, tu: TestUtils): @@ -4550,18 +4771,18 @@ def __init__(self): @export @annotate_args([ None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True) + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True) ]) def forward(self, query, key, value): return torch.ops.aten.scaled_dot_product_attention(query, key, value) @register_test_case(module_factory=lambda: ScaledDotProductAttentionSameModule()) def ScaledDotProductAttentionSameModule_basic(module, tu: TestUtils): - query = torch.randn(1, 1, 5, 5, dtype=torch.float32) - key = torch.randn(1, 1, 5, 5, dtype=torch.float32) - value = torch.randn(1, 1, 5, 5, dtype=torch.float32) + query = torch.randn(1, 5, 5, dtype=torch.float32) + key = torch.randn(1, 5, 5, dtype=torch.float32) + value = torch.randn(1, 5, 5, dtype=torch.float32) module.forward(query, key, value) class ScaledDotProductAttentionDifferentModule(torch.nn.Module): @@ -4572,18 +4793,18 @@ def __init__(self): @export @annotate_args([ None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True) + ([2, 3, 8, 4], torch.float32, True), + ([2, 3, 16, 4], torch.float32, True), + ([2, 3, 16, 4], torch.float32, True) ]) def forward(self, query, key, value): return torch.ops.aten.scaled_dot_product_attention(query, key, value) @register_test_case(module_factory=lambda: ScaledDotProductAttentionDifferentModule()) def ScaledDotProductAttentionDifferentModule_basic(module, tu: TestUtils): - query = torch.randn(3, 2, 8, 4, dtype=torch.float32) - key = torch.randn(3, 2, 16, 4, dtype=torch.float32) - value = torch.randn(3, 2, 16, 4, dtype=torch.float32) + query = torch.randn(2, 3, 8, 4, dtype=torch.float32) + key = torch.randn(2, 3, 16, 4, dtype=torch.float32) + value = torch.randn(2, 3, 16, 4, dtype=torch.float32) module.forward(query, key, value) # ============================================================================== @@ -5053,3 +5274,23 @@ def forward(self, x): @register_test_case(module_factory=lambda: IscloseStaticModuleTrue()) def IscloseStaticModuleTrue_basic(module, tu: TestUtils): module.forward(torch.ones(5, 5)) + + +# ============================================================================== + +class CloneModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([5, 5], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.clone(x) + +@register_test_case(module_factory=lambda: CloneModule()) +def CloneModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 5)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py index 5c00a75e06da..6f8240f54d89 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py @@ -28,7 +28,7 @@ def forward(self, x): for i in range(x_val): sum += i return sum - + @register_test_case(module_factory=lambda: TorchPrimLoopForLikeModule()) def TorchPrimLoopForLikeModule_basic(module, tu: TestUtils): @@ -50,7 +50,7 @@ def forward(self, x): while(x_val > sum): sum += 1 return sum - + @register_test_case(module_factory=lambda: TorchPrimLoopWhileLikeModule()) def TorchPrimLoopWhileLikeModule_basic(module, tu: TestUtils): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index b9ba1c0947bc..453d78d2f9c8 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -854,3 +854,134 @@ def forward(self, inputVec): @register_test_case(module_factory=lambda: UpSampleNearest2dSameFactor()) def UpSampleNearest2dStaticFactor_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4, 4)) +class Conv1dModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ]) + def forward(self, inputVec, weight, bias): + return torch.ops.aten.conv1d(inputVec, + weight, + bias=bias, + stride=[1], + padding=[0], + dilation=[1], + groups=1) +@register_test_case(module_factory=lambda: Conv1dModule()) +def Conv1dModule_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 2, 6) + weight = torch.randn(8, 2, 3) + bias = torch.randn(8) + module.forward(inputVec, weight, bias) + +class Conv2dModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ]) + def forward(self, inputVec, weight, bias): + return torch.ops.aten.conv2d(inputVec, + weight, + bias=bias, + stride=[1, 1], + padding=[0, 0], + dilation=[1, 1], + groups=1) +@register_test_case(module_factory=lambda: Conv2dModule()) +def Conv2dModule_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 2, 6, 6) + weight = torch.randn(8, 2, 3, 3) + bias = torch.randn(8) + module.forward(inputVec, weight, bias) + +class Conv3dModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ]) + def forward(self, inputVec, weight, bias): + return torch.ops.aten.conv3d(inputVec, + weight, + bias=bias, + stride=[1, 1, 1], + padding=[0, 0, 0], + dilation=[1, 1, 1], + groups=1) +@register_test_case(module_factory=lambda: Conv3dModule()) +def Conv3dModule_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 2, 6, 6, 6) + weight = torch.randn(8, 2, 3, 3, 3) + bias = torch.randn(8) + module.forward(inputVec, weight, bias) + +class ConvTbcModule(torch.nn.Module): + def __init__(self): + super().__init__() + + # shapes from https://github.com/pytorch/pytorch/blob/3e8c8ce37bbfaafa8581fb48506c0a70ea54463d/test/nn/test_convolution.py#L623 + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ]) + def forward(self, x, weight, bias): + return torch.conv_tbc(x, weight, bias) + +@register_test_case(module_factory=lambda: ConvTbcModule()) +def ConvTbcModule_basic(module, tu: TestUtils): + module.forward(tu.rand(9, 4, 5), tu.rand(3, 5, 6), tu.rand(6)) + +class Conv2dQInt8Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.int8, True), + ([-1, -1, -1, -1], torch.int8, True), + ([-1], torch.float, True), + ]) + def forward(self, inputVec, weight, bias): + inputVec = torch._make_per_tensor_quantized_tensor(inputVec, 0.01, 7) + inputVec = torch.dequantize(inputVec) + + weight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 3) + weight = torch.dequantize(weight) + + bias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32) + bias = torch.dequantize(bias) + + return torch.ops.aten.conv2d(inputVec, + weight, + bias=bias, + stride=[1, 1], + padding=[0, 0], + dilation=[1, 1], + groups=1) +@register_test_case(module_factory=lambda: Conv2dQInt8Module()) +def Conv2dQInt8Module_basic(module, tu: TestUtils): + inputVec = tu.randint(2, 4, 7, 8, low=-128, high=127).to(torch.int8) + weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8) + bias = torch.rand(3) + module.forward(inputVec, weight, bias) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py new file mode 100644 index 000000000000..d54bd11cb7d6 --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py @@ -0,0 +1,123 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch + +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export + +# ============================================================================== + +class DiagonalModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + + def forward(self, a): + return torch.ops.aten.diagonal(a) + + +@register_test_case(module_factory=lambda: DiagonalModule()) +def DiagonalModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 3)) + +@register_test_case(module_factory=lambda: DiagonalModule()) +def DiagonalModule_nonsquare(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + +# ============================================================================== + +class DiagonalTransposedModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diagonal(a, dim1=1, dim2=0) + +@register_test_case(module_factory=lambda: DiagonalTransposedModule()) +def DiagonalModule_transposed(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + +# ============================================================================== + +class DiagonalWithDimsModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diagonal(a, dim1=0, dim2=1) + +@register_test_case(module_factory=lambda: DiagonalWithDimsModule()) +def DiagonalModule_with_dims(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + +class DiagonalWithNegativeDimsModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diagonal(a, dim1=-2, dim2=-1) + +@register_test_case(module_factory=lambda: DiagonalWithNegativeDimsModule()) +def DiagonalModule_with_negative_dims(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + +class DiagonalWithOffsetModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diagonal(a, offset=1) + +@register_test_case(module_factory=lambda: DiagonalWithOffsetModule()) +def DiagonalModule_with_offset(module, tu: TestUtils): + module.forward(tu.rand(4, 6)) + +# ============================================================================== + +class DiagonalWithDimsOffsetModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diagonal(a, dim1=0, dim2=1, offset=-1) + +@register_test_case(module_factory=lambda: DiagonalWithDimsOffsetModule()) +def DiagonalModule_with_dims_and_offset(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index d79706dfc9dd..be7945847d17 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -63,6 +63,226 @@ def ElementwiseUnaryIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseCoshModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.cosh(a) + + +@register_test_case(module_factory=lambda: ElementwiseCoshModule()) +def ElementwiseCoshModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseCoshIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.cosh(a) + + +@register_test_case(module_factory=lambda: ElementwiseCoshIntModule()) +def ElementwiseCoshIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + + +# ============================================================================== + + +class ElementwiseAcoshModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.acosh(a) + + +@register_test_case(module_factory=lambda: ElementwiseAcoshModule()) +def ElementwiseAcoshModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseAcoshIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.acosh(a) + + +@register_test_case(module_factory=lambda: ElementwiseAcoshIntModule()) +def ElementwiseAcoshIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + + +# ============================================================================== + + +class ElementwiseAsinModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.asin(a) + + +@register_test_case(module_factory=lambda: ElementwiseAsinModule()) +def ElementwiseAsinModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseAsinIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.asin(a) + + +@register_test_case(module_factory=lambda: ElementwiseAsinIntModule()) +def ElementwiseAsinIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + + +# ============================================================================== + + +class ElementwiseAsinhModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.asinh(a) + + +@register_test_case(module_factory=lambda: ElementwiseAsinhModule()) +def ElementwiseAsinhModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseAsinhIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.asinh(a) + + +@register_test_case(module_factory=lambda: ElementwiseAsinhIntModule()) +def ElementwiseAsinhIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + + +# ============================================================================== + + +class ElementwiseAtanhModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.atanh(a) + + +@register_test_case(module_factory=lambda: ElementwiseAtanhModule()) +def ElementwiseAtanhModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseAtanhIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.atanh(a) + + +@register_test_case(module_factory=lambda: ElementwiseAtanhIntModule()) +def ElementwiseAtanhIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + + +# ============================================================================== + + class ElementwiseBinaryModule(torch.nn.Module): def __init__(self): @@ -193,7 +413,7 @@ def __init__(self): ([-1, -1, -1], torch.float32, True), ]) def forward(self, a): - return torch.where(a > 0.5, 4.0, 8.0) + return torch.where(a > 0.5, 4.0, 8.0).to(torch.float) @register_test_case(module_factory=lambda: ElementwiseWhereScalarModule()) @@ -295,6 +515,33 @@ def ElementwiseWhereScalarSelfStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseNanToNumModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.float32, True) + ]) + def forward(self, a): + return torch.ops.aten.nan_to_num(a, 0.0, 1.0, -1.0) + +@register_test_case(module_factory=lambda: ElementwiseNanToNumModule()) +def ElementwiseNanToNumModule_Basic(module, tu: TestUtils): + module.forward(torch.tensor( + [ + [float('nan'), 0.0, float('nan'), 0.0], + [float('inf'), 0.0, float('inf'), 0.0], + [float('-inf'), 0.0, float('-inf'), 0.0] + ] + )) + + +# ============================================================================== + + # Addition is an interesting special case of a binary op, because under the hood # it carries a third scalar "alpha" parameter, which needs special handling. class ElementwiseAddModule(torch.nn.Module): @@ -474,6 +721,48 @@ def forward(self, x): def ElementwiseLeakyReluStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6, low=-1)) + +# ============================================================================== + + +class ElementwiseLerpScalarIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ]) + def forward(self, a, b): + return torch.ops.aten.lerp(a, b, weight=2) + +@register_test_case(module_factory=lambda: ElementwiseLerpScalarIntModule()) +def ElementwiseLerpScalarIntModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5,3), tu.rand(5,3)) + + +class ElementwiseLerpScalarFloatModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ]) + def forward(self, a, b): + return torch.ops.aten.lerp(a, b, weight=0.5) + +@register_test_case(module_factory=lambda: ElementwiseLerpScalarFloatModule()) +def ElementwiseLerpScalarFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5,3), tu.rand(5,3)) + + # ============================================================================== @@ -564,6 +853,50 @@ def ElementwiseGeluModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseGeluApproximateTanhModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.gelu = torch.nn.GELU(approximate="tanh") + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, x): + return self.gelu(x) + + +@register_test_case(module_factory=lambda: ElementwiseGeluApproximateTanhModule()) +def ElementwiseGeluApproximateTanhModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 3, low=-0.5, high=0.5)) + + +# ============================================================================== + + +class ElementwiseSeluModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.selu(x) + +@register_test_case(module_factory=lambda: ElementwiseSeluModule()) +def ElementwiseSeluModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 3, low=-1, high=1)) + + +# ============================================================================== + + class ElementwiseSigmoidModule(torch.nn.Module): def __init__(self): @@ -948,6 +1281,34 @@ def ElementwiseClampTensorIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseClampTensorInt8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int8, True) + ]) + def forward(self, x): + min = -5 + max = 5 + min_clamp = torch.clamp(x, min) + max_clamp = torch.clamp(x, max=max) + both_clamp = torch.clamp(x, min=min, max=max) + return min_clamp, max_clamp, both_clamp + + +@register_test_case(module_factory=lambda: ElementwiseClampTensorInt8Module()) +def ElementwiseClampTensorInt8Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 5, low=-10, high=10, dtype=torch.int8)) + + +# ============================================================================== + + + class ElementwiseClampMinTensorFloatModule(torch.nn.Module): def __init__(self): @@ -1553,6 +1914,28 @@ def ElementwiseLog1pModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseLogitModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.logit(a, eps=1e-7) + + +@register_test_case(module_factory=lambda: ElementwiseLogitModule()) +def ElementwiseLogitModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + class ElementwiseErfModule(torch.nn.Module): def __init__(self): @@ -2053,21 +2436,43 @@ def __init__(self): @export @annotate_args([ None, - ([-1, -1], torch.int32, True), + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.rsqrt(a) + + +@register_test_case(module_factory=lambda: ElementwiseRsqrtIntModule()) +def ElementwiseRsqrtIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + + +# ============================================================================== + + +class ElementwiseAbsFloatModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), ]) def forward(self, a): - return torch.rsqrt(a) + return torch.abs(a) -@register_test_case(module_factory=lambda: ElementwiseRsqrtIntModule()) -def ElementwiseRsqrtIntModule_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) +@register_test_case(module_factory=lambda: ElementwiseAbsFloatModule()) +def ElementwiseAbsFloatModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([[[-1.0, 0.0, 1.0]]])) # ============================================================================== -class ElementwiseAbsModule(torch.nn.Module): +class ElementwiseAbsIntModule(torch.nn.Module): def __init__(self): super().__init__() @@ -2075,15 +2480,15 @@ def __init__(self): @export @annotate_args([ None, - ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.int64, True), ]) def forward(self, a): return torch.abs(a) -@register_test_case(module_factory=lambda: ElementwiseAbsModule()) -def ElementwiseAbsModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5, low=-1.0, high=1.0)) +@register_test_case(module_factory=lambda: ElementwiseAbsIntModule()) +def ElementwiseAbsIntModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([[[-1, 0, 1]]])) # ============================================================================== @@ -2259,6 +2664,135 @@ def ElementwiseRemainderScalarModule_Bool_basic(module, tu: TestUtils): # ============================================================================== + +class ElementwiseFmodTensor_Float(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.float32, True), + ([-1], torch.float32, True) + ]) + def forward(self, x, y): + return torch.fmod(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseFmodTensor_Float()) +def ElementwiseFmodTensor_Float_basic(module, tu: TestUtils): + module.forward(tu.rand(100, low=-10, high=10), tu.rand(100, low=-10, high=10)) + +# ============================================================================== + +class ElementwiseFmodTensor_Int_Float(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.int32, True), + ([-1], torch.float32, True) + ]) + def forward(self, x, y): + return torch.fmod(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseFmodTensor_Int_Float()) +def ElementwiseFmodTensor_Int_Float_basic(module, tu: TestUtils): + module.forward(tu.randint(100, low=-10, high=10).to(torch.int32), tu.rand(100, low=-10, high=10)) + +# ============================================================================== + +class ElementwiseFmodTensor_Int(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.int32, True), + ([-1], torch.int32, True), + ]) + def forward(self, x, y): + return torch.fmod(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseFmodTensor_Int()) +def ElementwiseFmodTensor_Int_basic(module, tu: TestUtils): + module.forward(tu.randint(100, low=0, high=1000).to(torch.int32), tu.randint(100, low=1, high=1000).to(torch.int32)) + # ============================================================================== + + +class ElementwiseRemainderTensorModule_Int_Float(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ([-1, -1], torch.float32, True), + ]) + def forward(self, a, b): + return torch.remainder(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseRemainderTensorModule_Int_Float()) +def ElementwiseRemainderTensorModule_Int_Float_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, high=10).to(torch.int32), tu.rand(3, 4, high=10)) + + +# ============================================================================== + + +class ElementwiseRemainderTensorModule_Float(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ]) + def forward(self, a, b): + return torch.remainder(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseRemainderTensorModule_Float()) +def ElementwiseRemainderTensorModule_Float_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, high=10), tu.rand(3, 4, high=10)) + + +# ============================================================================== + +class ElementwiseRemainderTensorModule_Int(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ([-1, -1], torch.int32, True), + ]) + def forward(self, a, b): + return torch.remainder(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseRemainderTensorModule_Int()) +def ElementwiseRemainderTensorModule_Int_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, high=10, dtype=torch.int32), tu.randint(3, 4, high=10, dtype=torch.int32)) + +# ============================================================================== class ElementwiseDivTensorFloatModule(torch.nn.Module): @@ -2284,6 +2818,52 @@ def ElementwiseDivTensorFloatModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseDivTensorIntegerModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.int32, True), + ]) + def forward(self, a, b): + return torch.div(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseDivTensorIntegerModule()) +def ElementwiseDivTensorIntegerModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-10, high=10), tu.randint(3, 4, low=-10, high=10).type(torch.int32)) + + +# ============================================================================== + + +class ElementwiseDivTensorUnsignedIntegerModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.uint8, True), + ([-1, -1], torch.uint8, True), + ]) + def forward(self, a, b): + return torch.div(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseDivTensorUnsignedIntegerModule()) +def ElementwiseDivTensorUnsignedIntegerModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=0, high=10).to(torch.uint8), tu.randint(3, 4, low=0, high=10).type(torch.uint8)) + + +# ============================================================================== + + class ElementwiseDivRoundingModeTruncModule(torch.nn.Module): def __init__(self): @@ -3082,6 +3662,46 @@ def ElementwiseAcosIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseTanModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.tan(a) + + +@register_test_case(module_factory=lambda: ElementwiseTanModule()) +def ElementwiseTanModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + +# ============================================================================== + +class ElementwiseTanIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.tan(a) + + +@register_test_case(module_factory=lambda: ElementwiseTanIntModule()) +def ElementwiseTanIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + +# ============================================================================== + class ElementwiseNegModule(torch.nn.Module): def __init__(self): @@ -3418,6 +4038,83 @@ def ElementwiseAtenLogicalNotOpModule_basic(module, tu: TestUtils): module.forward(tu.randint(4, 5, high=2).bool()) +# ============================================================================== + +class ElementwiseAtenIsinfOpModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 5], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.isinf(x) + +@register_test_case(module_factory=lambda: ElementwiseAtenIsinfOpModule()) +def ElementwiseAtenIsinfOpModule_basic(module, tu: TestUtils): + test_input = torch.tensor( + [ + [1, float('inf'), 2, float('-inf'), float('nan')], + [1, float('inf'), float('-inf'), float('nan'), 3], + ] + ) + module.forward(test_input) + + +# ============================================================================== + + +class ElementwiseAtenIsneginfOpModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 5], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.isneginf(x) + +@register_test_case(module_factory=lambda: ElementwiseAtenIsneginfOpModule()) +def ElementwiseAtenIsneginfOpModule_basic(module, tu:TestUtils): + test_input = torch.tensor( + [ + [1, float('-inf'), 2, float('inf'), float('nan')], + [1, float('-inf'), float('inf'), float('nan'), 3], + ] + ) + module.forward(test_input) + + +# ============================================================================== + + +class ElementwiseAtenIsposinfOpModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 5], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.isposinf(x) + +@register_test_case(module_factory=lambda: ElementwiseAtenIsposinfOpModule()) +def ElementwiseAtenIsposinfOpModule_basic(module, tu:TestUtils): + test_input = torch.tensor( + [ + [1, float('-inf'), 2, float('inf'), float('nan')], + [1, float('-inf'), float('inf'), float('nan'), 3], + ] + ) + module.forward(test_input) + + # ============================================================================== @@ -4094,10 +4791,83 @@ def ElementwiseBitwiseAndScalarInt8Module_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseQuantizePerTensorModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float, True), + ]) + def forward(self, x): + scale = 0.04 + zp = -110 + dtype = torch.qint8 + # We return the int representation as we can not map to quint8 type yet on boundaries. + q = torch.quantize_per_tensor(x, scale, zp, dtype).int_repr() + return q + +@register_test_case(module_factory=lambda: ElementwiseQuantizePerTensorModule()) +def ElementwiseQuantizePerTensorModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + +class ElementwiseDequantizePerTensorModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int8, True), + ]) + def forward(self, x): + qx = torch._make_per_tensor_quantized_tensor(x, 0.1, 8) + qx = torch.dequantize(qx) + return qx + +@register_test_case(module_factory=lambda: ElementwiseDequantizePerTensorModule()) +def ElementwiseDequantizePerTensorModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8)) + +# ============================================================================== + +class ElementwiseDequantizePerChannelModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.int8, True), + ([4], torch.int8, True), + ([4], torch.float, True), + ]) + def forward(self, x, zeropoint, scale): + qx = torch._make_per_channel_quantized_tensor(x, scale, zeropoint, axis=1) + qx = torch.dequantize(qx) + return qx + +@register_test_case(module_factory=lambda: ElementwiseDequantizePerChannelModule()) +def ElementwiseDequantizePerChannelModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(3, 4, low=-128, high=127).to(torch.int8), + tu.randint(4, low=-128, high=127).to(torch.int8), + tu.rand(4) + ) + +# ============================================================================== + class GluStaticModule(torch.nn.Module): def __init__(self): super().__init__() - + @export @annotate_args([ None, diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py index ac04eeb41109..6248ef5aa32c 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py @@ -779,7 +779,7 @@ def __init__(self): def forward(self): input = [True, False, True, True, False] return torch.ops.aten.all(input) - + @register_test_case(module_factory=lambda: AllBoolFalseModule()) def AllBoolFalseModule_basic(module, tu: TestUtils): module.forward() diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/gridsampler.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/gridsampler.py new file mode 100644 index 000000000000..2960041bdc68 --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/gridsampler.py @@ -0,0 +1,71 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import torch + +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export + +# ============================================================================== + +class GridSamplerBasic1(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([7, 8, 12, 4], torch.float32, True), + ([7, 11, 13, 2], torch.float32, True) + ]) + def forward(self, x, g): + interpolation_mode=0, + padding_mode=0, + align_corners=True, + tRes = torch.ops.aten.grid_sampler(x, g, interpolation_mode[0], + padding_mode[0], align_corners[0]) + return tRes + +@register_test_case( + module_factory=lambda: GridSamplerBasic1()) +def GridSamplerBasic1_basic( + module, tu: TestUtils): + inp = torch.rand(7,8,12,4) + grd = torch.rand(7,11,13,2)*2-1 + module.forward(inp, grd) + + +class GridSamplerBasic2(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 4, 4], torch.float32, True), + ([1, 1, 3, 2], torch.float32, True) + ]) + def forward(self, x, g): + interpolation_mode=0, + padding_mode=0, + align_corners=True, + tRes = torch.ops.aten.grid_sampler(x, g, interpolation_mode[0], + padding_mode[0], align_corners[0]) + return tRes + +@register_test_case( + module_factory=lambda: GridSamplerBasic2()) +def GridSamplerBasic2_basic( + module, tu: TestUtils): + inp = torch.tensor([[[[0.4963, 0.7682, 0.0885, 0.1320], + [0.3074, 0.6341, 0.4901, 0.8964], + [0.4556, 0.6323, 0.3489, 0.4017], + [0.0223, 0.1689, 0.2939, 0.5185]]]]).type(torch.FloatTensor) + grd = torch.tensor([[[[-0.3498, -0.8196],[-0.2127, 0.2138],[-0.6515, -0.0513]]]]).type(torch.FloatTensor) + module.forward(inp, grd) + diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py index e59279ab57f7..2ccd9d9d39c8 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py @@ -28,7 +28,7 @@ def forward(self, lhs, rhs): @register_test_case(module_factory=lambda: MatmulDot()) def Matmul_dot(module, tu: TestUtils): module.forward(tu.rand(3), tu.rand(3)) - + # ============================================================================== class Matmul2D(torch.nn.Module): @@ -48,7 +48,7 @@ def forward(self, lhs, rhs): @register_test_case(module_factory=lambda: Matmul2D()) def Matmul_2d(module, tu: TestUtils): module.forward(tu.rand(3, 4), tu.rand(4, 5)) - + # ============================================================================== class MatmulVecMat(torch.nn.Module): @@ -68,7 +68,7 @@ def forward(self, lhs, rhs): @register_test_case(module_factory=lambda: MatmulVecMat()) def Matmul_vecmat(module, tu: TestUtils): module.forward(tu.rand(4), tu.rand(4, 5)) - + # ============================================================================== class MatmulMatVec(torch.nn.Module): @@ -88,7 +88,7 @@ def forward(self, lhs, rhs): @register_test_case(module_factory=lambda: MatmulMatVec()) def Matmul_matvec(module, tu: TestUtils): module.forward(tu.rand(4, 5), tu.rand(5)) - + # ============================================================================== class Matmul3D(torch.nn.Module): @@ -108,7 +108,7 @@ def forward(self, lhs, rhs): @register_test_case(module_factory=lambda: Matmul3D()) def Matmul_3d(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4)) - + # ============================================================================== class Matmul4d(torch.nn.Module): @@ -128,7 +128,7 @@ def forward(self, lhs, rhs): @register_test_case(module_factory=lambda: Matmul4d()) def Matmul_4d(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6)) - + # ============================================================================== class Matmul4dStatic(torch.nn.Module): @@ -151,6 +151,26 @@ def Matmul4dStatic_basic(module, tu: TestUtils): # ============================================================================== +class Matmul4dStaticBroadcast(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([10, 6, 2], torch.float32, True), + ([10, 10, 2, 6], torch.float32, True), + ]) + def forward(self, lhs, rhs): + return torch.matmul(lhs, rhs) + + +@register_test_case(module_factory=lambda: Matmul4dStaticBroadcast()) +def Matmul4dStaticBroadcast_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 6, 2), tu.rand(10, 10, 2, 6)) + +# ============================================================================== + class MatmulStaticBroadcast(torch.nn.Module): def __init__(self): super().__init__() @@ -188,7 +208,7 @@ def forward(self, lhs, rhs): @register_test_case(module_factory=lambda: MatmulSingleDynamicBatchDim()) def MatmulSingleDynamicBatchDim_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6)) - + # ============================================================================== class MatmulBroadcastBatchDim(torch.nn.Module): @@ -208,7 +228,7 @@ def forward(self, lhs, rhs): @register_test_case(module_factory=lambda: MatmulBroadcastBatchDim()) def MatmulBroadcastBatchDim_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6, 7), tu.rand(5, 7, 6)) - + # ============================================================================== class Mv(torch.nn.Module): @@ -262,3 +282,141 @@ def forward(self, a, b): @register_test_case(module_factory=lambda: AtenMmIntTypes()) def AtenMmIntTypes_basic(module, tu: TestUtils): module.forward(tu.randint(16, 4, high=100), tu.randint(4, 16, high=100)) + + +# ============================================================================== + +class AtenMmQuint8(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.int8, True), + ([4, 3], torch.int8, True), + ]) + def forward(self, x, y): + qx = torch._make_per_tensor_quantized_tensor(x, 0.1, 8) + qx = torch.dequantize(qx) + qy = torch._make_per_tensor_quantized_tensor(y, 0.1, 8) + qy = torch.dequantize(qy) + qz = torch.mm(qx, qy) + return qz + +@register_test_case(module_factory=lambda: AtenMmQuint8()) +def AtenMmQuint8_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8), + tu.randint(4, 3, low=-128, high=127).to(torch.int8)) + +# ============================================================================== + +class AtenLinalgCrossInt(torch.nn.Module): + + @export + @annotate_args([ + None, + ([2, 3], torch.int64, True), + ([2, 3], torch.int64, True), + ]) + def forward(self, a, b): + return torch.ops.aten.linalg_cross(a, b) + + +@register_test_case(module_factory=lambda: AtenLinalgCrossInt()) +def AtenLinalgCrossInt_basic(module, tu: TestUtils): + module.forward(tu.randint(2, 3), tu.randint(2, 3)) + +# ============================================================================== + +class AtenLinalgCrossFloat(torch.nn.Module): + + @export + @annotate_args([ + None, + ([2, 3], torch.float32, True), + ([2, 3], torch.float32, True), + ]) + def forward(self, a, b): + return torch.ops.aten.linalg_cross(a, b) + + +@register_test_case(module_factory=lambda: AtenLinalgCrossFloat()) +def AtenLinalgCrossFloat_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3), tu.rand(2, 3)) + + +# ============================================================================== + +class AtenLinalgCrossBroadcast(torch.nn.Module): + + @export + @annotate_args([ + None, + ([1, 4, 3], torch.float32, True), + ([5, 4, 3], torch.float32, True), + ]) + def forward(self, a, b): + return torch.ops.aten.linalg_cross(a, b) + + +@register_test_case(module_factory=lambda: AtenLinalgCrossBroadcast()) +def AtenLinalgCrossBroadcast_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 4, 3), tu.rand(5, 4, 3)) + +# ============================================================================== + +class AtenLinalgCrossCustomDim(torch.nn.Module): + + @export + @annotate_args([ + None, + ([1, 4, 3, 2, 2], torch.float32, True), + ([5, 4, 3, 2, 1], torch.float32, True), + ]) + def forward(self, a, b): + return torch.ops.aten.linalg_cross(a, b, dim=2) + + +@register_test_case(module_factory=lambda: AtenLinalgCrossCustomDim()) +def AtenLinalgCrossCustomDim_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 4, 3, 2, 2), tu.rand(5, 4, 3, 2, 1)) + +# ============================================================================== + +class AtenLinalgCrossNegativeDim(torch.nn.Module): + + @export + @annotate_args([ + None, + ([1, 4, 3, 2, 2], torch.float32, True), + ([5, 4, 3, 2, 1], torch.float32, True), + ]) + def forward(self, a, b): + return torch.ops.aten.linalg_cross(a, b, dim=-3) + + +@register_test_case(module_factory=lambda: AtenLinalgCrossNegativeDim()) +def AtenLinalgCrossNegativeDim_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 4, 3, 2, 2), tu.rand(5, 4, 3, 2, 1)) + +# ============================================================================== + +class AtenLinalgCrossDynamic(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, a, b): + return torch.ops.aten.linalg_cross(a, b, dim=1) + + +@register_test_case(module_factory=lambda: AtenLinalgCrossDynamic()) +def AtenLinalgCrossDynamic_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 3, 1, 6), tu.rand(4, 3, 7, 1)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py index f59695620064..56821fb694f3 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py @@ -130,7 +130,7 @@ def __init__(self): ]) def forward(self, x, weight, bias, running_mean, running_var): return torch.ops.aten.batch_norm( - x, weight, bias, running_mean, running_var, training=False, + x, weight, bias, running_mean, running_var, training=False, momentum=0.1, eps=0.00001, cudnn_enabled=False) @@ -156,7 +156,7 @@ def __init__(self): ]) def forward(self, x, weight, bias, running_mean, running_var): return torch.ops.aten.native_batch_norm( - x, weight, bias, running_mean, running_var, training=False, + x, weight, bias, running_mean, running_var, training=False, momentum=0.1, eps=0.00001) @@ -182,7 +182,7 @@ def __init__(self): ]) def forward(self, x, weight, bias, running_mean, running_var): return torch.ops.aten.native_batch_norm( - x, weight, bias, running_mean, running_var, training=False, + x, weight, bias, running_mean, running_var, training=False, momentum=0.1, eps=0.00001) @@ -208,7 +208,7 @@ def __init__(self): ]) def forward(self, x, weight, bias, running_mean, running_var): return torch.ops.aten.native_batch_norm( - x, weight, bias, running_mean, running_var, training=False, + x, weight, bias, running_mean, running_var, training=False, momentum=0.1, eps=0.00001) @@ -233,7 +233,7 @@ def __init__(self): ]) def forward(self, x, bias, running_mean, running_var): return torch.ops.aten.native_batch_norm( - x, None, bias, running_mean, running_var, training=False, + x, None, bias, running_mean, running_var, training=False, momentum=0.1, eps=0.00001) @@ -243,6 +243,42 @@ def NativeBatchNormNoneWeightModule_basic(module, tu: TestUtils): # ============================================================================== +class GroupNormModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 4, 6, 7], torch.float32, True), + ([4], torch.float32, True), + ([4], torch.float32, True), + ]) + def forward(self, x, weight, bias): + return torch.ops.aten.group_norm(x, 2, weight, bias, 1.0000000000000001e-05, False) + +@register_test_case(module_factory=lambda: GroupNormModule()) +def GroupNormModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4, 6, 7), tu.rand(4), tu.rand(4)) + +class GroupNormNoWeightAndBiasModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 4, 6, 7], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.group_norm(x, 2, None, None, 1.0000000000000001e-05, False) + +@register_test_case(module_factory=lambda: GroupNormNoWeightAndBiasModule()) +def GroupNormNoWeightAndBiasModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4, 6, 7)) + +# ============================================================================== + class NativeGroupNormModule(torch.nn.Module): def __init__(self): super().__init__() @@ -257,13 +293,15 @@ def __init__(self): def forward(self, x, weight, bias): return torch.ops.aten.native_group_norm( x, weight, bias, - 2, 6, 4, 3, 0.000001); + 2, 6, 4, 3, 0.000001) @register_test_case(module_factory=lambda: NativeGroupNormModule()) def NativeGroupNormModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 6, 2, 2), tu.rand(6), tu.rand(6)) +# ============================================================================== + class NativeGroupNormBackwardModule(torch.nn.Module): def __init__(self): super().__init__() @@ -280,7 +318,7 @@ def __init__(self): def forward(self, grad_out, x, mean, rstd, weight): return torch.ops.aten.native_group_norm_backward( grad_out, x, mean, rstd, weight, - 2, 6, 4, 3, [True, True, True]); + 2, 6, 4, 3, [True, True, True]) @register_test_case(module_factory=lambda: NativeGroupNormBackwardModule()) @@ -450,3 +488,22 @@ def forward(self, x): @register_test_case(module_factory=lambda: LayerNormNormalizeOverAllDimsModule()) def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 2, 3)) + +class AtenInstanceNormModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 2, 1, 3], torch.float32, True), + ([2], torch.float32, True), + ([2], torch.float32, True) + ]) + def forward(self, x, w, b): + return torch.ops.aten.instance_norm(x, w, b, None, + None, True, 0.0, 1e-05, False) + +@register_test_case(module_factory=lambda: AtenInstanceNormModule()) +def AtenInstanceNormModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 2, 1, 3), tu.rand(2), tu.rand(2)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py new file mode 100644 index 000000000000..59961fedcc27 --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py @@ -0,0 +1,111 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import functorch +import torch + +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export + +# ============================================================================== + +class ReflectionPad2dModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 20, 20], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad2d(x, (10,10,10,10)) + + +@register_test_case(module_factory=lambda: ReflectionPad2dModule()) +def ReflectionPad2dModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 20, 20, low=-1)) + +# ============================================================================== + +class ReflectionPad2dModuleTop(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 3, 4], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad2d(x, (0,0,2,0)) + + +@register_test_case(module_factory=lambda: ReflectionPad2dModuleTop()) +def ReflectionPad2dModule_Top(module, tu: TestUtils): + module.forward(tu.rand(1, 3, 4)) + +# ============================================================================== + +class ReflectionPad2dModuleBottom(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 3, 10, 10], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad2d(x, (0,0,0,5)) + + +@register_test_case(module_factory=lambda: ReflectionPad2dModuleBottom()) +def ReflectionPad2dModule_Bottom(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 10, 10)) + +# ============================================================================== + +class ReflectionPad2dModuleLeft(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 3, 20, 20], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad2d(x, (15,0,0,0)) + + +@register_test_case(module_factory=lambda: ReflectionPad2dModuleLeft()) +def ReflectionPad2dModule_Left(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 20, 20)) + +# ============================================================================== + +class ReflectionPad2dModuleRight(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 3, 20, 20], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad2d(x, (0,11,0,0)) + + +@register_test_case(module_factory=lambda: ReflectionPad2dModuleRight()) +def ReflectionPad2dModule_Right(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 20, 20)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index dd18545b0bc4..22ff3bb330ad 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -11,7 +11,6 @@ # ============================================================================== - class AdaptiveAvgPool2dNonUnitOutputSizeStaticModule(torch.nn.Module): def __init__(self): @@ -55,7 +54,6 @@ def AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic( module, tu: TestUtils): module.forward(tu.rand(1, 512, 7, 7)) - class AdaptiveAvgPool2dUnitOutputSizeStaticModule(torch.nn.Module): def __init__(self): @@ -213,6 +211,154 @@ def MaxPool2dCeilModeTrueModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 20, 20, low=0.5, high=1.0)) +# ============================================================================== + +class MaxPool3dModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.mp3d = torch.nn.MaxPool3d(kernel_size=[4, 4, 4], + stride=[2, 2, 2], + padding=[1, 1, 1], + dilation=1) + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return self.mp3d(x) + + +@register_test_case(module_factory=lambda: MaxPool3dModule()) +def MaxPool3dModule_basic(module, tu: TestUtils): + module.forward(torch.arange(8*8*8).view(1, 1, 8, 8, 8).float()) + +class MaxPool3dRandomSimpleModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mp3d = torch.nn.MaxPool3d(kernel_size=[4, 4, 4], + stride=[2, 2, 2], + padding=[1, 1, 1], + dilation=1) + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return self.mp3d(x) + + +@register_test_case(module_factory=lambda: MaxPool3dRandomSimpleModule()) +def MaxPool3dModuleRandomSimple_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 20, 20, 20, low=-1)) + +class MaxPool3dLargeDataModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.mp3d = torch.nn.MaxPool3d(kernel_size=[6, 8, 8], + stride=[2, 2, 2], + padding=[3, 4, 4], + dilation=2) + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return self.mp3d(x) + + +@register_test_case(module_factory=lambda: MaxPool3dLargeDataModule()) +def MaxPool3dLargeDatadModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 20, 20, 20, low=-1)) + +class MaxPool3dEmptyStrideStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + @export + @annotate_args([ + None, + ([1, 1, 20, 20, 20], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.max_pool3d(x, kernel_size=2, stride=[]) + + +@register_test_case(module_factory=lambda: MaxPool3dEmptyStrideStaticModule()) +def MaxPool3dEmptyStrideStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 20, 20, 20, low=-1)) + + +class MaxPool3dStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mp3d = torch.nn.MaxPool3d(kernel_size=[3, 3, 3], + stride=[2, 2, 2], + padding=[1, 1, 1], + dilation=[1, 1, 1]) + @export + @annotate_args([ + None, + ([1, 64, 112, 112, 112], torch.float32, True), + ]) + def forward(self, x): + return self.mp3d(x) + + +@register_test_case(module_factory=lambda: MaxPool3dStaticModule()) +def MaxPool3dStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 64, 112, 112, 112)) + +class MaxPool3dStaticCeilModeTrueModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mp3d = torch.nn.MaxPool3d(kernel_size=[3, 3, 3], + stride=[2, 2, 2], + padding=[1, 1, 1], + dilation=[1, 1, 1], + ceil_mode=True) + + @export + @annotate_args([ + None, + ([1, 64, 112, 112, 112], torch.float32, True), + ]) + def forward(self, x): + return self.mp3d(x) + + +@register_test_case(module_factory=lambda: MaxPool3dStaticCeilModeTrueModule()) +def MaxPool3dStaticCeilModeTrueModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 64, 112, 112, 112)) + + +class MaxPool3dCeilModeTrueModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mp3d = torch.nn.MaxPool3d(kernel_size=[6, 8, 8], + stride=[2, 2, 2], + padding=[3, 4, 4], + dilation=2, + ceil_mode=True) + @export + @annotate_args([ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return self.mp3d(x) + +@register_test_case(module_factory=lambda: MaxPool3dCeilModeTrueModule()) +def MaxPool3dCeilModeTrueModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 20, 20, 20, low=0.5, high=1.0)) + + # ============================================================================== @@ -701,6 +847,28 @@ def forward(self, x): def AvgPool2dCeilModeTrueModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0)) +class AvgPool2dWithoutPadModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d(kernel_size=[6, 8], + stride=[2, 2], + padding=[0, 0], + ceil_mode=False, + count_include_pad=False, + divisor_override=None) + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return self.ap2d(x) + +@register_test_case(module_factory=lambda: AvgPool2dWithoutPadModule()) +def AvgPool2dWithoutPadModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0)) # ============================================================================== @@ -776,12 +944,71 @@ def AvgPool1dStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class AdaptiveAvgPool1dStaticLargerOutput(torch.nn.Module): + + def __init__(self): + super().__init__() + self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=13) + + @export + @annotate_args([ + None, + ([5, 512, 7], torch.float32, True) + ]) + def forward(self,x): + return self.aap1d(x) + +@register_test_case( + module_factory=lambda: AdaptiveAvgPool1dStaticLargerOutput()) +def AdaptiveAvgPool1dStaticLargerOutput_basic( + module, tu: TestUtils): + module.forward(tu.rand(5, 512, 7)) + +class AdaptiveAvgPool1dStaticEvenMultiple(torch.nn.Module): + + def __init__(self): + super().__init__() + self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=7) + + @export + @annotate_args([ + None, + ([5, 512, 147], torch.float32, True) + ]) + def forward(self,x): + return self.aap1d(x) + +@register_test_case( + module_factory=lambda: AdaptiveAvgPool1dStaticEvenMultiple()) +def AdaptiveAvgPool1dStaticEvenMultiple_basic( + module, tu: TestUtils): + module.forward(tu.rand(5, 512, 147)) + +class AdaptiveAvgPool1dGeneralDynamic(torch.nn.Module): + + def __init__(self): + super().__init__() + self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=7) + + @export + @annotate_args([ + None, + ([-1,-1,-1], torch.float32, True) + ]) + def forward(self,x): + return self.aap1d(x) + +@register_test_case( + module_factory=lambda: AdaptiveAvgPool1dGeneralDynamic()) +def AdaptiveAvgPool1dGeneralDynamic_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10)) class AdaptiveAvgPool1dNonUnitOutputSizeStaticModule(torch.nn.Module): def __init__(self): super().__init__() - self.aap1d = torch.nn.AdaptiveAvgPool1d(7) + self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=7) @export @annotate_args([ @@ -801,7 +1028,7 @@ class AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule(torch.nn.Module): def __init__(self): super().__init__() - self.aap1d = torch.nn.AdaptiveAvgPool1d(7) + self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=7) @export @annotate_args([ @@ -821,7 +1048,7 @@ class AdaptiveAvgPool1dUnitOutputSizeStaticModule(torch.nn.Module): def __init__(self): super().__init__() - self.aap1d = torch.nn.AdaptiveAvgPool1d(1) + self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=1) @export @annotate_args([ @@ -841,7 +1068,7 @@ class AdaptiveAvgPool1dUnitOutputSizeDynamicModule(torch.nn.Module): def __init__(self): super().__init__() - self.aap1d = torch.nn.AdaptiveAvgPool1d(1) + self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=1) @export @annotate_args([ @@ -855,4 +1082,85 @@ def forward(self, x): module_factory=lambda: AdaptiveAvgPool1dUnitOutputSizeDynamicModule()) def AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic( module, tu: TestUtils): - module.forward(tu.rand(1, 512, 7)) \ No newline at end of file + module.forward(tu.rand(1, 512, 7)) + +class AdaptiveMaxPool2dDynamic(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=False) + + @export + @annotate_args([ + None, + ([-1,-1,-1,-1], torch.float32, True) + ]) + def forward(self,x): + return self.amp2d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool2dDynamic()) +def AdaptiveMaxPool2dDynamic_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10, 16)) + +class AdaptiveMaxPool2dDynamicWithIndices(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=True) + + @export + @annotate_args([ + None, + ([-1,-1,-1,-1], torch.float32, True) + ]) + def forward(self,x): + return self.amp2d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool2dDynamicWithIndices()) +def AdaptiveMaxPool2dDynamicWithIndices_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10, 16)) + + +class AdaptiveMaxPool2dStatic(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=False) + + @export + @annotate_args([ + None, + ([1, 512, 10, 9], torch.float32, True) + ]) + def forward(self,x): + return self.amp2d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool2dStatic()) +def AdaptiveMaxPool2dStatic_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10, 9)) + +class AdaptiveMaxPool2dStaticWithIndices(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=True) + + @export + @annotate_args([ + None, + ([1, 512, 10, 16], torch.float32, True) + ]) + def forward(self,x): + return self.amp2d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool2dStaticWithIndices()) +def AdaptiveMaxPool2dStaticWithIndices_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10, 16)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 585a68e55af4..af7712f258b2 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -335,6 +335,78 @@ def ReduceProdDimIntFloatModule_basic(module, tu: TestUtils): # ============================================================================== +class ReduceAllDimEmpty(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.all(a, dim=0, keepdim=False) + +@register_test_case(module_factory=lambda: ReduceAllDimEmpty()) +def ReduceAllDimEmpty_basic(module, tu: TestUtils): + module.forward(torch.tensor([])) + +# ============================================================================== + +class ReduceAllDimFloat(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1,-1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.all(a, dim=1, keepdim=True) + +@register_test_case(module_factory=lambda: ReduceAllDimFloat()) +def ReduceAllDimFloat_basic(module, tu: TestUtils): + module.forward(torch.tensor([[5.0,1e-6,-5.0],[0,5.0,0]])) + +# ============================================================================== + +class ReduceAllDimInt(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1,-1], torch.int32, True), + ]) + def forward(self, a): + return torch.ops.aten.all(a, dim=1, keepdim=True) + +@register_test_case(module_factory=lambda: ReduceAllDimInt()) +def ReduceAllDimInt_basic(module, tu: TestUtils): + module.forward(torch.tensor([[5,-5,0],[5,1e10,5]]).to(torch.int32)) + +# ============================================================================== + +class ReduceAllDimBool(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1,-1], torch.bool, True), + ]) + def forward(self, a): + return torch.ops.aten.all(a, dim=1, keepdim=False) + +@register_test_case(module_factory=lambda: ReduceAllDimBool()) +def ReduceAllDimBool_basic(module, tu: TestUtils): + module.forward(torch.tensor([[True, False, True], [True, True, True]])) + +# ============================================================================== + class ReduceMaxAlongDim(torch.nn.Module): def __init__(self): super().__init__() @@ -845,7 +917,7 @@ def __init__(self): @export @annotate_args([ - None, + None, ([-1, -1], torch.float32, True), ]) def forward(self, a): @@ -927,7 +999,7 @@ def __init__(self): @export @annotate_args([ - None, + None, ([-1, -1], torch.float32, True), ]) def forward(self, a): @@ -1047,6 +1119,25 @@ def ReduceL3NormKeepDimModule_basic(module, tu: TestUtils): # ============================================================================== +class NormScalarModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.p = 3.0 + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.norm(a, self.p) + +@register_test_case(module_factory=lambda: NormScalarModule()) +def NormScalarModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + class NormScalarOptDimModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1087,8 +1178,8 @@ def NormScalarOptDimKeepDimModule_basic(module, tu: TestUtils): class ReduceFrobeniusNormModule(torch.nn.Module): def __init__(self) -> None: super().__init__() - - @export + + @export @annotate_args([ None, ([-1, -1, -1], torch.float32, True), @@ -1105,8 +1196,8 @@ def ReduceFrobeniusNormModule_basic(module, tu: TestUtils): class ReduceFrobeniusNormKeepDimModule(torch.nn.Module): def __init__(self) -> None: super().__init__() - - @export + + @export @annotate_args([ None, ([-1, -1, -1], torch.float32, True), @@ -1123,8 +1214,8 @@ def ReduceFrobeniusNormKeepDimModule_basic(module, tu: TestUtils): class LinalgVectorNormModule(torch.nn.Module): def __init__(self) -> None: super().__init__() - - @export + + @export @annotate_args([ None, ([-1, -1, -1], torch.float32, True), @@ -1141,8 +1232,8 @@ def LinalgVectorNormModule_basic(module, tu: TestUtils): class LinalgVectorNormKeepDimModule(torch.nn.Module): def __init__(self) -> None: super().__init__() - - @export + + @export @annotate_args([ None, ([-1, -1, -1], torch.float32, True), @@ -1156,6 +1247,42 @@ def LinalgVectorNormKeepDimModule_basic(module, tu: TestUtils): # ============================================================================== +class LinalgNormModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.linalg_norm(a, ord=None, dim=[0], keepdim=False) + +@register_test_case(module_factory=lambda: LinalgNormModule()) +def LinalgNormModule_basic(module, tu: TestUtils): + module.forward(torch.rand(3, 4, 5)) + +# ============================================================================== + +class LinalgNormKeepDimModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.linalg_norm(a, ord=None, dim=[0], keepdim=True) + +@register_test_case(module_factory=lambda: LinalgNormKeepDimModule()) +def LinalgNormKeepDimModule_basic(module, tu: TestUtils): + module.forward(torch.rand(3, 4, 5)) + +# ============================================================================== + class MseLossNoReductionModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1249,3 +1376,56 @@ def forward(self, input, target): @register_test_case(module_factory=lambda: CrossEntropyLossNoReductionModule()) def CrossEntropyLossNoReductionModule_basic(module, tu: TestUtils): module.forward(tu.rand(8, 2), tu.randint(8, high=2)) + +# ============================================================================== + +class TraceModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.trace(a) + +@register_test_case(module_factory=lambda: TraceModule()) +def TraceModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 3)) + +@register_test_case(module_factory=lambda: TraceModule()) +def TraceModule_nonsquare(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + +@register_test_case(module_factory=lambda: TraceModule()) +def TraceModule_empty(module, tu: TestUtils): + module.forward(torch.empty(0,0)) + +# ============================================================================== + +class TraceIntModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ]) + def forward(self, a): + return torch.ops.aten.trace(a) + +@register_test_case(module_factory=lambda: TraceIntModule()) +def TraceSignedIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(2, 2, low=-10, high=10)) + +@register_test_case(module_factory=lambda: TraceIntModule()) +def TraceUnsignedIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(2, 2, low=0, high=10)) + +@register_test_case(module_factory=lambda: TraceIntModule()) +def TraceUnsignedIntModule_empty(module, tu: TestUtils): + module.forward(tu.randint(0, 0, low=0, high=10)) + diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py index a73435c3c1ad..73371058cf46 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -708,8 +708,8 @@ def UnsafeView1DFoldModule_basic(module, tu: TestUtils): class ReshapeAsModule(torch.nn.Module): def __init__(self) -> None: super().__init__() - - @export + + @export @annotate_args([ None, ([4, 3], torch.float32, True), diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py index 1baa462462f1..2b8e186ff401 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py @@ -157,6 +157,29 @@ def UniformNoCorrelationModule_basic(module, tu: TestUtils): # ============================================================================== +class ExponentialModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float64, True), + ]) + def forward(self, x): + a = torch.ops.aten.exponential(x, 3.0) + mean = torch.mean(a) + std = torch.std(a) + return mean, std + + +@register_test_case(module_factory=lambda: ExponentialModule()) +def ExponentialModule_basic(module, tu: TestUtils): + module.forward( + tu.rand(512, 512, 16).double()) + +# ============================================================================== + class BernoulliModule(torch.nn.Module): def __init__(self): super().__init__() @@ -582,3 +605,24 @@ def forward(self, x): @register_test_case(module_factory=lambda: RandnLikeDtypeModule()) def RandnLikeDtypeModule_basic(module, tu: TestUtils): module.forward(tu.rand(256, 1024).double()) +# ============================================================================== + +class NormalFunctionalModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float64, True), + ]) + def forward(self, x): + a = torch.ops.aten.normal_functional(x, mean=-5.0, std=2.0) + mean = torch.mean(a) + std = torch.std(a) + return mean, std + + +@register_test_case(module_factory=lambda: NormalFunctionalModule()) +def NormalFunctionalModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2048, 4096).double()) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py index 74717d99fb4e..51b9fb993088 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py @@ -78,6 +78,28 @@ def SubFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand().double(), tu.rand().double()) +# ============================================================================== + +class MulFloatModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.float64, True), + ([], torch.float64, True), + ]) + def forward(self, lhs, rhs): + return float(lhs) * float(rhs) + + +@register_test_case(module_factory=lambda: MulFloatModule()) +def MulFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand().double(), tu.rand().double()) + + # ============================================================================== @@ -428,3 +450,41 @@ def forward(self, val): @register_test_case(module_factory=lambda: AtenIntTensorCharDtypeModule()) def AtenIntTensorCharDtypeModule_basic(module, tu: TestUtils): module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8)) + +# ============================================================================== + +class AtenItemIntOpModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.int8, True), + ]) + + def forward(self, val): + return int(val) + +@register_test_case(module_factory=lambda: AtenItemIntOpModule()) +def AtenItemIntOpModule_basic(module, tu: TestUtils): + module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8)) + +# ============================================================================== + +class AtenItemFpOpModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.float, True), + ]) + + def forward(self, val): + return float(val) + +@register_test_case(module_factory=lambda: AtenItemFpOpModule()) +def AtenItemFpOpModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py index 176ad8506b53..7753b3139a41 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py @@ -1195,3 +1195,4 @@ def forward(self, input, index1, index2, value): module_factory=lambda: IndexPutImplIndexWithNoneModule()) def IndexPutImplIndexWithNoneModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4, 5), tu.randint(6, 1, high=4), tu.randint(7, high=5), tu.rand(2, 3, 6, 7)) + diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py index b13f23a1c014..7c39b3cd7f08 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py @@ -518,7 +518,7 @@ def __init__(self): ]) def forward(self, x): return torch.ops.aten.narrow(x, dim=0, start=0, length=2) - + @register_test_case(module_factory=lambda: NarrowHorizontalTest()) def NarrowHorizontalTest_basic(module, tu: TestUtils): @@ -557,7 +557,7 @@ def __init__(self): ]) def forward(self, x): return torch.ops.aten.narrow(x, dim=0, start=0, length=2) - + @register_test_case(module_factory=lambda: NarrowHorizontalTest2()) def NarrowHorizontalTest2_basic(module, tu: TestUtils): @@ -892,7 +892,7 @@ def SplitTensorGetItem_Module_basic(module, tu: TestUtils): class SplitTensorListUnpackModule(torch.nn.Module): def __init__(self): super().__init__() - + @export @annotate_args([ None, @@ -1051,3 +1051,25 @@ def forward(self, x): @register_test_case(module_factory=lambda: ChunkListUnpackUnevenDynamic_Module()) def ChunkListUnpackUnevenDynamic_Module_basic(module, tu: TestUtils): module.forward(tu.rand(2, 13, 2)) + +# ============================================================================== + +class SplitWithSizes_Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([5, -1, -1], torch.float32, True), + ]) + def forward(self, x): + split = torch.split(x, [2, 1, 2], dim=0) + return split[0], split[1], split[2] + +@register_test_case(module_factory=lambda: SplitWithSizes_Module()) +def SplitWithSizes_Module_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 2, 2)) + + + diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/squeeze.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/squeeze.py index 8b7cf957ac78..078f3483bed8 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/squeeze.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/squeeze.py @@ -96,7 +96,7 @@ def forward(self, a): module_factory=lambda: SqueezeDimStaticModule()) def SqueezeDimModule_static(module, tu: TestUtils): module.forward(tu.rand(1, 7)) - + # ============================================================================== diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py index 6e04c5fa8700..9c85eb873326 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py @@ -275,7 +275,7 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: TypeAsDifferentModule()) def TypeAsDifferentModule_basic(module, tu: TestUtils): module.forward( - tu.randint(3, 5, low=0, high=10, dtype=torch.int), + tu.randint(3, 5, low=0, high=10, dtype=torch.int), tu.randint(3, 5, low=0, high=10, dtype=torch.int64) ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/utils.py b/projects/pt1/python/torch_mlir_e2e_test/utils.py index 403c455cba64..e3a76581f668 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/utils.py +++ b/projects/pt1/python/torch_mlir_e2e_test/utils.py @@ -3,13 +3,13 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. -from torch_mlir import TensorPlaceholder +from torch_mlir.torchscript import TensorPlaceholder from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME def convert_annotations_to_placeholders(forward_method): """Converts the annotations on a forward method into tensor placeholders. - These placeholders are suitable for being passed to `torch_mlir.compile`. + These placeholders are suitable for being passed to `torchscript.compile`. """ annotations = getattr(forward_method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME) placeholders = [] diff --git a/projects/pt1/test/python/custom_op_shape_dtype_fn.py b/projects/pt1/test/python/custom_op_shape_dtype_fn.py index a46f1c594031..a3a2b965d655 100644 --- a/projects/pt1/test/python/custom_op_shape_dtype_fn.py +++ b/projects/pt1/test/python/custom_op_shape_dtype_fn.py @@ -5,7 +5,7 @@ import torch import torch.multiprocessing as mp import torch.utils.cpp_extension -import torch_mlir +from torch_mlir import torchscript from torch_mlir_e2e_test.annotations import export, annotate_args @@ -56,7 +56,7 @@ def run(): mod = CustomOpExampleModule() mod.eval() - module = torch_mlir.compile( + module = torchscript.compile( mod, torch.ones(3, 4), output_type="torch", diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-error.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-error.py index 26eaa5bd0cb1..0979d04228b5 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-error.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-error.py @@ -33,7 +33,10 @@ def forward(self, tensor): try: annotator.annotateArgs(class_type, ['forward'], [None]) except Exception as e: - # CHECK: Arg annotations should have one entry per function parameter (including self). + # CHECK: There must be one argument annotation per function parameter. + # CHECK-SAME: Including 'self' the number of argument annotations is: 1. + # CHECK-SAME: The number of function parameters is: 2. + # CHECK-SAME: The function signature is (__torch__.TestModule self, Tensor tensor) print(e) try: diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/unimplemented.py b/projects/pt1/test/python/importer/jit_ir/node_import/unimplemented.py index eb6bb2f09ff3..533ef7586748 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/unimplemented.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/unimplemented.py @@ -1,5 +1,5 @@ import torch -import torch_mlir +from torch_mlir import torchscript # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s @@ -39,6 +39,6 @@ def forward(self, data): with torch.no_grad(): return data -output_type = torch_mlir.OutputType.RAW -mod = torch_mlir.compile(Model(), [torch.tensor([0, 1, 2, 3])], output_type) +output_type = torchscript.OutputType.RAW +mod = torchscript.compile(Model(), [torch.tensor([0, 1, 2, 3])], output_type) print(mod) diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt new file mode 100644 index 000000000000..e52135599864 --- /dev/null +++ b/python/CMakeLists.txt @@ -0,0 +1,118 @@ +# Disables generation of "version soname" (i.e. libFoo.so.), which +# causes pure duplication as part of Python wheels. +set(CMAKE_PLATFORM_NO_VERSIONED_SONAME ON) + +# The directory at which the Python import tree begins. +# See documentation for `declare_mlir_python_sources`'s ROOT_DIR +# argument. +set(TORCH_MLIR_PYTHON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/torch_mlir") + + +# We vendor our own MLIR instance in the `torch_mlir` namespace. +add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=torch_mlir.") + +################################################################################ +# Sources +################################################################################ + +declare_mlir_python_sources(TorchMLIRPythonSources) +declare_mlir_python_sources(TorchMLIRPythonExtensions) + +declare_mlir_python_sources(TorchMLIRPythonSources.Dialects + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TorchMLIRPythonSources +) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT TorchMLIRPythonSources.Dialects + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + TD_FILE dialects/TorchBinding.td + SOURCES dialects/torch/__init__.py + DIALECT_NAME torch +) + +declare_mlir_python_sources(TorchMLIRPythonSources.Importers + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TorchMLIRPythonSources + SOURCES + extras/fx_importer.py + extras/onnx_importer.py +) + +declare_mlir_python_sources(TorchMLIRPythonSources.PublicAPI + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TorchMLIRPythonSources + SOURCES + fx.py + extras/fx_decomp_util.py +) + +declare_mlir_python_sources(TorchMLIRPythonSources.Tools + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TorchMLIRPythonSources + SOURCES + tools/import_onnx/__main__.py +) + +declare_mlir_python_sources(TorchMLIRSiteInitialize + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TorchMLIRPythonSources + SOURCES + _mlir_libs/_site_initialize_0.py +) + +################################################################################ +# Extensions +################################################################################ + +declare_mlir_python_extension(TorchMLIRPythonExtensions.Main + MODULE_NAME _torchMlir + ADD_TO_PARENT TorchMLIRPythonExtensions + SOURCES + TorchMLIRModule.cpp + EMBED_CAPI_LINK_LIBS + TorchMLIRCAPI + PRIVATE_LINK_LIBS + LLVMSupport +) + +################################################################################ +# Generate packages and shared library +# Downstreams typically will not use these, but they are useful for local +# testing. +################################################################################ + +set(_source_components + # TODO: Core is now implicitly building/registering all dialects, increasing + # build burden by ~5x. Make it stop. + # TODO: Reduce dependencies. We need ExecutionEngine and a bunch of passes + # for the reference backend, but logically they can be separate. But seemingly + # the only way to handle that is to create a separate mlir python package + # tree, which seems excessive. + MLIRPythonSources + MLIRPythonExtension.Core + MLIRPythonExtension.RegisterEverything + TorchMLIRPythonSources + TorchMLIRPythonExtensions + TorchMLIRSiteInitialize + + # Sources related to optional Torch extension dependent features. Typically + # empty unless if project features are enabled. + TorchMLIRPythonTorchExtensionsSources +) + +add_mlir_python_common_capi_library(TorchMLIRAggregateCAPI + INSTALL_COMPONENT TorchMLIRPythonModules + INSTALL_DESTINATION python_packages/torch_mlir/torch_mlir/_mlir_libs + OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs" + RELATIVE_INSTALL_ROOT ".." + DECLARED_SOURCES ${_source_components} +) + +add_mlir_python_modules(TorchMLIRPythonModules + ROOT_PREFIX "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir" + INSTALL_PREFIX "python_packages/torch_mlir/torch_mlir" + DECLARED_SOURCES ${_source_components} + COMMON_CAPI_LINK_LIBS + TorchMLIRAggregateCAPI + ) diff --git a/projects/pt1/python/TorchMLIRModule.cpp b/python/TorchMLIRModule.cpp similarity index 100% rename from projects/pt1/python/TorchMLIRModule.cpp rename to python/TorchMLIRModule.cpp diff --git a/python/torch_mlir/_mlir_libs/_site_initialize_0.py b/python/torch_mlir/_mlir_libs/_site_initialize_0.py new file mode 100644 index 000000000000..3b93b1fa930d --- /dev/null +++ b/python/torch_mlir/_mlir_libs/_site_initialize_0.py @@ -0,0 +1,9 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# Multi-threading rarely helps the frontend and we are also running in contexts +# where we want to run a lot of test parallelism (and nproc*nproc threads +# puts a large load on the system and virtual memory). +disable_multithreading = True diff --git a/projects/pt1/python/torch_mlir/dialects/TorchBinding.td b/python/torch_mlir/dialects/TorchBinding.td similarity index 100% rename from projects/pt1/python/torch_mlir/dialects/TorchBinding.td rename to python/torch_mlir/dialects/TorchBinding.td diff --git a/projects/pt1/python/torch_mlir/dialects/torch/__init__.py b/python/torch_mlir/dialects/torch/__init__.py similarity index 100% rename from projects/pt1/python/torch_mlir/dialects/torch/__init__.py rename to python/torch_mlir/dialects/torch/__init__.py diff --git a/python/torch_mlir/extras/fx_decomp_util.py b/python/torch_mlir/extras/fx_decomp_util.py new file mode 100644 index 000000000000..47a79f95597e --- /dev/null +++ b/python/torch_mlir/extras/fx_decomp_util.py @@ -0,0 +1,50 @@ +import torch +from torch._decomp import get_decompositions + +# default decompositions pulled from SHARK / torch._decomp +DEFAULT_DECOMPOSITIONS = [ + torch.ops.aten.embedding_dense_backward, + torch.ops.aten.native_layer_norm_backward, + torch.ops.aten.slice_backward, + torch.ops.aten.select_backward, + torch.ops.aten.norm.ScalarOpt_dim, + torch.ops.aten.native_group_norm, + torch.ops.aten.upsample_bilinear2d.vec, + torch.ops.aten.split.Tensor, + torch.ops.aten.split_with_sizes, + torch.ops.aten.native_layer_norm, + torch.ops.aten.masked_fill.Tensor, + torch.ops.aten.masked_fill.Scalar, + torch.ops.aten.t, + torch.ops.aten.addmm, + # decompositions that aid us in handling nn.BatchNorm2d + torch.ops.aten._native_batch_norm_legit_functional, + torch.ops.aten._native_batch_norm_legit_no_training, + torch.ops.aten._native_batch_norm_legit, + torch.ops.aten._native_batch_norm_legit.no_stats, + torch.ops.aten.squeeze.dims, + # decompositions for miscellaneous ops that are not handled in torch-mlir but have available decompositions + torch.ops.aten.soft_margin_loss, + torch.ops.aten.im2col, + torch.ops.aten._euclidean_dist, + torch.ops.aten.index_copy, + torch.ops.aten.index_copy_, + torch.ops.aten.grid_sampler_2d, + torch.ops.aten.log_sigmoid_forward, + torch.ops.aten.unsafe_split.Tensor, + torch.ops.aten.binary_cross_entropy, + torch.ops.aten.dot, + torch.ops.aten._adaptive_avg_pool2d, + torch.ops.aten._prelu_kernel, + torch.ops.aten.full, + torch.ops.aten._log_softmax, + torch.ops.aten.nll_loss_forward, + torch.ops.aten.nll_loss_backward, + torch.ops.aten._to_copy, + torch.ops.aten._log_softmax_backward_data, + torch.ops.aten.lift_fresh_copy.default, + torch.ops.aten._unsafe_index.Tensor, +] + +def get_decomposition_table(): + return get_decompositions(DEFAULT_DECOMPOSITIONS) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py new file mode 100644 index 000000000000..952b638c1988 --- /dev/null +++ b/python/torch_mlir/extras/fx_importer.py @@ -0,0 +1,1825 @@ +# Copyright 2023 Advanced Micro Devices, Inc +# +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +try: + from types import NoneType +except ImportError: + # python less than 3.10 doesn't have NoneType + NoneType = type(None) + +import logging +import operator +import re +from dataclasses import dataclass +from types import BuiltinMethodType, BuiltinFunctionType +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Set, + Tuple, + TYPE_CHECKING, + Union, +) +import weakref + +import numpy as np + +import torch +import torch.export +import torch.fx as torch_fx +from torch.fx.passes.shape_prop import TensorMetadata + +from torch import ( + dtype as TorchDtype, + FunctionSchema, +) + +from torch._ops import ( + OpOverload as TorchOpOverload, +) + +from torch._subclasses import ( + FakeTensor as TorchFakeTensor, +) + +from torch.fx import ( + Graph, + GraphModule, + Node, +) + +try: + from torch.export.graph_signature import InputSpec as TypingInputSpec +except ModuleNotFoundError: + # PyTorch prior to 2.3 is missing certain things we use in typing + # signatures. Just make them be Any. + if not TYPE_CHECKING: + TypingInputSpec = Any + else: + raise + +try: + import ml_dtypes +except ModuleNotFoundError: + # The third-party ml_dtypes package provides some optional + # low precision data-types. If used in this file, it is + # conditional. + ml_dtypes = None + +from torch.fx.node import ( + Argument as NodeArgument, +) + +from ..ir import ( + Attribute, + Block, + Context, + DenseElementsAttr, + DenseResourceElementsAttr, + FloatAttr, + BF16Type, + ComplexType, + F16Type, + F32Type, + F64Type, + FunctionType, + InsertionPoint, + IntegerAttr, + IntegerType, + RankedTensorType, + Location, + Module, + Operation, + StringAttr, + SymbolTable, + Type as IrType, + Value, +) + +from ..dialects import ( + func as func_dialect, +) + +__all__ = [ + "FxImporter", +] + +REQUIRED_DIALCTS = [ + "builtin", + "func", + "torch", +] + +TORCH_DTYPE_TO_MLIR_TYPE_ASM = { + torch.float16: "f16", + torch.bfloat16: "bf16", + torch.float32: "f32", + torch.float64: "f64", + torch.uint8: "ui8", + torch.int8: "si8", + torch.int16: "si16", + torch.int32: "si32", + torch.int64: "si64", + torch.bool: "i1", + torch.qint8: "!torch.qint8", + torch.quint8: "!torch.quint8", + torch.complex32: "complex", + torch.complex64: "complex", + torch.complex128: "complex", +} + +TORCH_DTYPE_TO_MLIR_TYPE: Dict[torch.dtype, Callable[[], IrType]] = { + torch.float16: lambda: F16Type.get(), + torch.bfloat16: lambda: BF16Type.get(), + torch.float32: lambda: F32Type.get(), + torch.float64: lambda: F64Type.get(), + torch.uint8: lambda: IntegerType.get_unsigned(8), + torch.int8: lambda: IntegerType.get_signed(8), + torch.int16: lambda: IntegerType.get_signed(16), + torch.int32: lambda: IntegerType.get_signed(32), + torch.int64: lambda: IntegerType.get_signed(64), + torch.bool: lambda: IntegerType.get_signless(1), + torch.qint8: lambda: IntegerType.get_signed(8), + torch.quint8: lambda: IntegerType.get_unsigned(8), + torch.complex32: lambda: ComplexType.get(F16Type.get()), + torch.complex64: lambda: ComplexType.get(F32Type.get()), + torch.complex128: lambda: ComplexType.get(F64Type.get()), +} + +TORCH_DTYPE_TO_NPY_TYPE = { + # torch.qint8: None, # no equivalent np datatype + # torch.quint8: None, + torch.uint8: np.uint8, + torch.int8: np.int8, + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.int64, + torch.float16: np.float16, + torch.float32: np.float32, + torch.float64: np.float64, + torch.bool: np.bool_, + # torch.complex32: None, # no equivalent precision for numpy + torch.complex64: np.complex64, + torch.complex128: np.complex128, +} +if ml_dtypes is not None: + TORCH_DTYPE_TO_NPY_TYPE[torch.bfloat16] = ml_dtypes.bfloat16 + +TORCH_DTYPE_TO_INT = { + torch.uint8: 0, + torch.int8: 1, + torch.int16: 2, + torch.int32: 3, + torch.int64: 4, + torch.float16: 5, + torch.float32: 6, + torch.float64: 7, + # torch.complex_half 8 + torch.complex32: 9, + torch.complex64: 10, + torch.bool: 11, + # torch.qint8: 12, # quantized dtypes are not supported in all backends, currently we do not support them + # torch.quint8: 13, + # torch.qint32 14 + torch.bfloat16: 15, +} + +TORCH_MEMORY_FORMAT_TO_INT = { + torch.contiguous_format: 0, + torch.preserve_format: 1, + torch.channels_last: 2, + torch.channels_last_3d: 3, +} + +TORCH_LAYOUT_TO_INT = { + torch.strided: 0, + torch.sparse_coo: 1, + torch.sparse_csr: 2, + torch.sparse_csc: 3, + torch.sparse_bsr: 4, + torch.sparse_bsc: 5, +} + +PY_BUILTIN_TO_TORCH_OP = { + "truediv": torch.ops.aten.div, + "mul": torch.ops.aten.mul, + "add": torch.ops.aten.add, + "sub": torch.ops.aten.sub, + "lt": torch.ops.aten.lt, + "le": torch.ops.aten.le, + "ge": torch.ops.aten.ge, + "ne": torch.ops.aten.ne, + "gt": torch.ops.aten.gt, +} + +# torch with cuda has a __version__ that looks like "2.1.0+cu113", +# so split by + and 0 index will always give the base version +_IS_TORCH_2_1_OR_EARLIER = torch.__version__.split("+")[0] <= "2.1.0" + +# The following are maps from symbolic ops to their non symbolic equivalents. +# In <=2.1.0, imported fx graphs come with a type inspecific torch.ops.aten.sym_size +# We identify it using the number of args in the node, 1 being default, 2 being int +# In the mapping below (torch.aten.sym_size, 2) indicates len(args)=2 therefore +# map to torch.aten.size.int. +# Thankfully, newer versions provide a specific torch.ops.aten.sym_size.. +# Once we drop support for <2.1.0, we can get rid of the the SYMBOLIC_TORCH_OPS +# set and just check key existence in SYMBOLIC_OP_TO_TORCH_OP + +if _IS_TORCH_2_1_OR_EARLIER: + SYMBOLIC_TORCH_OPS = { + torch.ops.aten.sym_size, + torch.ops.aten.sym_stride, + torch.ops.aten.sym_numel, + } + + SYMBOLIC_OP_TO_TORCH_OP = { + (torch.ops.aten.sym_size, 1): torch.ops.aten.size.default, + (torch.ops.aten.sym_size, 2): torch.ops.aten.size.int, + (torch.ops.aten.sym_stride, 1): torch.ops.aten.stride.default, + (torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int, + (torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default, + } +else: + SYMBOLIC_TORCH_OPS = { + torch.ops.aten.sym_size.int, + torch.ops.aten.sym_stride.int, + torch.ops.aten.sym_numel.default, + } + + SYMBOLIC_OP_TO_TORCH_OP = { + torch.ops.aten.sym_size.default: torch.ops.aten.size.default, + torch.ops.aten.sym_size.int: torch.ops.aten.size.int, + torch.ops.aten.sym_stride.default: torch.ops.aten.stride.default, + torch.ops.aten.sym_stride.int: torch.ops.aten.stride.int, + torch.ops.aten.sym_numel.default: torch.ops.aten.numel.default, + } + + +@dataclass(frozen=True) +class SparsityMeta: + """ + Class for keeping track of sparsity meta data. + + NOTE: this will be fully replaced by + torch.fx.passes.shape_prop.SparseTensorMetadata + """ + + layout: torch.layout + batch_dim: int + sparse_dim: int + dense_dim: int + blocksize: Optional[tuple[int, int]] + pos_dtype: torch.dtype + crd_dtype: torch.dtype + + +def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str: + """Returns sparse tensor encoding for the given sparse layout as string.""" + assert sparsity is not None + + # Sparse tensors have the form + # [ , , ] + # which map directly to MLIR types. + batch_dim, sparse_dim, dense_dim = ( + sparsity.batch_dim, + sparsity.sparse_dim, + sparsity.dense_dim, + ) + dim = batch_dim + sparse_dim + dense_dim + assert dim == len(shape) + blocksize = sparsity.blocksize + + dims = ",".join(f"d{d}" for d in range(0, dim)) + + if sparsity.layout is torch.sparse_coo: + assert sparse_dim == 2 and blocksize is None # TODO: deeper sparse dims + lvls = f"d{batch_dim}:compressed(nonunique),d{batch_dim+1}:singleton(soa)" + elif sparsity.layout is torch.sparse_csr: + assert sparse_dim == 2 and blocksize is None + lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed" + elif sparsity.layout is torch.sparse_csc: + assert sparse_dim == 2 and blocksize is None + lvls = f"d{batch_dim+1}:dense,d{batch_dim}:compressed" + else: + assert sparse_dim == 2 and blocksize is not None + if sparsity.layout is torch.sparse_bsr: + i, j = batch_dim, batch_dim + 1 + else: + assert sparsity.layout is torch.sparse_bsc + j, i = batch_dim, batch_dim + 1 + m, n = blocksize + lvls = ( + f"d{i} floordiv {m}:dense,d{j} floordiv {n}:compressed," + f"d{i} mod {m}:dense,d{j} mod {n}:dense" + ) + + if batch_dim > 0: + batch = ",".join(f"d{d}:dense" for d in range(0, batch_dim)) + lvls = f"{batch},{lvls}" + + if dense_dim > 0: + dense = ",".join(f"d{d}:dense" for d in range(batch_dim + sparse_dim, dim)) + lvls = f"{lvls},{dense}" + + posw = torch.iinfo(sparsity.pos_dtype).bits + crdw = torch.iinfo(sparsity.crd_dtype).bits + return f"#sparse_tensor.encoding<{{map=({dims})->({lvls}),posWidth={posw},crdWidth={crdw}}}>" + + +def is_symbolic(obj: Any) -> bool: + """Check whether an object in our graph is symbolic""" + return isinstance(obj, (torch.SymInt, torch.SymFloat, torch.SymBool)) + + +def is_builtin_function_or_method(obj: Any) -> bool: + return isinstance(obj, (BuiltinMethodType, BuiltinFunctionType)) + + +# TODO: switch back to `slots=True` when py3.9 support is dropped +@dataclass(frozen=True) +class InputInfo: + """Provides additional metadata when resolving inputs.""" + + __slots__ = [ + "program", + "input_spec", + "node", + "ir_type", + "mutable_producer_node_name", + ] + + program: torch.export.ExportedProgram + input_spec: TypingInputSpec + node: Node + ir_type: IrType + mutable_producer_node_name: Optional[str] + + +class FxImporterHooks: + """Hooks to control the behavior of the FxImporter.""" + + def prepare_module(self, module_op: Operation): + """Performs any needed preparation work on the module.""" + ... + + def resolve_literal( + self, gni: "GraphNodeImporter", literal: Any + ) -> Optional[Value]: + """User overridable hook to resolve a literal value.""" + return None + + def resolve_input( + self, gni: "GraphNodeImporter", value: Any, info: InputInfo + ) -> Optional[Value]: + """Resolves a Parameter or Buffer input to an IR value. + + If the 'mutable_producer_node_name' option is set, then the result must + be a `!torch.tensor`. + Otherwise, it must be an immutable `!torch.vtensor`. If this constraint cannot + be met, the implementation must either error or return None to delegate to + the default. + """ + return None + + +class FxImporter: + """Main entry-point for importing an fx.GraphModule. + + The FxImporter is a low-level class intended for framework integrators. + It provides several options for customization: + + * config_check: Optionally allows some per-import configuration safety + checks to be skipped. + * literal_resolver_callback: Callback that will be invoked when a literal, + live torch.Tensor is encountered in the FX graph, allowing the default + action (which is to inline the data as a DenseResourceElementsAttr) to + be completely overriden. + * py_attr_tracker: Weak reference tracker for live PyTorch objects used + to unique them with respect to attributes. If not specified, there will + be one reference tracker per import, but this can be injected to share + the same uniqueing across imports (i.e. if building multiple functions + into the same context or module). + """ + + __slots__ = [ + "_c", + "_cc", + "_m", + "_m_ip", + "_py_attr_tracker", + "_hooks", + "symbol_table", + ] + + def __init__( + self, + *, + module: Optional[Module] = None, + context: Optional[Context] = None, + config_check: bool = True, + py_attr_tracker: Optional["RefTracker"] = None, + hooks: Optional[FxImporterHooks] = None, + ): + if module is not None: + assert context is None, "If configuring with a Module, context must be None" + self._m = module + self._c = self.module.context + else: + self._c = context if context else Context() + self._m = Module.create(Location.unknown(self._c)) + if config_check: + # Production code can disable this for a bit of a boost. + self._config_check() + self._py_attr_tracker = py_attr_tracker or RefTracker() + self._cc = ContextCache(self._c, py_attr_tracker=self._py_attr_tracker) + self._m_ip = InsertionPoint(self._m.body) + self._hooks = hooks or FxImporterHooks() + self.symbol_table = SymbolTable(self._m.operation) + self._hooks.prepare_module(self._m.operation) + + def _config_check(self): + for dname in REQUIRED_DIALCTS: + try: + self._c.dialects[dname] + logging.debug("Context has registered dialect '%s'", dname) + except IndexError: + raise RuntimeError( + f"The MLIR context {self._c} is missing required dialect '{dname}'" + ) + + @property + def module(self) -> Module: + return self._m + + @property + def module_op(self) -> Operation: + return self._m.operation + + def import_program( + self, prog: torch.export.ExportedProgram, *, func_name: str = "main" + ): + """Imports an ExportedProgram according to our chosen canonical representation. + + This mechanism is the fully general solution for handling an ExportedProgram + and should eventually supercede all others. However, it depends on the + PyTorch 2.3 release to function properly (specifically, this patch + made ExportedProgram minimally correct for mutation: + https://github.com/pytorch/pytorch/pull/118969). + + For stateless programs, the result of this import is a normal function + defined for immutable `!torch.vtensors`. + + However, if the program mutates its inputs or buffers, then it will be imported + with those parameters as `!torch.tensor` and appropriate copies and overwrites + will be done on the inside. Note that the function is still mostly stateless, + but with `torch.copy.to_vtensor` and `torch.overwrite.tensor.contents` + ops at the earliest consumer or latest producer to update an argument or + buffer. + + It is recommended that integrators subclass and override the `resolve_literal` + method to control access to mutable buffers and parameters. Without that, the + default policy is to capture them as frozen values. + """ + # Create lookaside table of placeholders/outputs. + placeholder_nodes: dict[str, Node] = {} + all_producer_nodes: dict[str, Node] = {} + loc: Optional[Location] = None + for node in prog.graph.nodes: + if loc is None: + loc = self._cc.get_node_location(node) + if node.op == "placeholder": + placeholder_nodes[node.name] = node + all_producer_nodes[node.name] = node + elif node.op == "call_function": + all_producer_nodes[node.name] = node + if loc is None: + loc = Location.unknown(self._c) + + # This API is fast evolving. We keep these imports local for now so that we + # can disable this entire function if needed. + from torch.export.graph_signature import ( + InputKind, + OutputKind, + TensorArgument, + SymIntArgument, + ) + + sig = prog.graph_signature + + # Invert the (producer, node_name) maps for mutated user inputs and mutated + # buffers. This is because we hit-detect based on the input node name. + mutated_user_inputs = { + node_name: producer + for producer, node_name in sig.user_inputs_to_mutate.items() + } + + # Additional bindings that we need to set up after the function is created. + mutable_buffer_target_producers: dict[str, str] = {} + constant_tensors: dict[Node, torch.Tensor] = {} + parameter_bindings: dict[Node, tuple[Any, InputInfo]] = {} + buffer_bindings: dict[Node, tuple[Any, InputInfo]] = {} + + # Derive user outputs that we preserve. These will be nodes of the + # producer for the output. + user_outputs: list[Node] = [] + user_output_types: list[IrType] = [] + for output_spec in sig.output_specs: + kind = output_spec.kind + arg = output_spec.arg + if kind == OutputKind.USER_OUTPUT: + if not isinstance(arg, (TensorArgument, SymIntArgument)): + raise NotImplementedError( + f"OutputKind.USER_OUTPUT for {type(arg)}: {arg}" + ) + output_producer_node = all_producer_nodes[arg.name] + user_outputs.append(output_producer_node) + user_output_types.append( + self._cc.node_val_to_type(output_producer_node) + ) + elif kind == OutputKind.BUFFER_MUTATION and isinstance(arg, TensorArgument): + mutable_buffer_target_producers[output_spec.target] = arg.name + + # Derive user inputs. These will be op=='placeholder' nodes. + user_inputs: list[Node] = [] + user_input_types: list[IrType] = [] + for input_spec in sig.input_specs: + arg = input_spec.arg + if input_spec.kind == InputKind.USER_INPUT: + # Set up user input. + if not isinstance(arg, (TensorArgument, SymIntArgument)): + raise NotImplementedError( + f"InputKind.USER_INPUT for {type(arg)}: {arg}" + ) + placeholder_node = placeholder_nodes[arg.name] + mutable = placeholder_node.name in mutated_user_inputs + user_inputs.append(placeholder_node) + user_input_types.append( + self._cc.node_val_to_type(placeholder_node, mutable=mutable) + ) + elif input_spec.kind == InputKind.CONSTANT_TENSOR and isinstance( + arg, TensorArgument + ): + # Remember constant tensor binding. + constant_tensors[placeholder_nodes[arg.name]] = prog.constants[ + input_spec.target + ] + elif input_spec.kind == InputKind.PARAMETER and isinstance( + arg, TensorArgument + ): + # Remember parameter binding. + value = prog.state_dict.get(input_spec.target) + assert ( + not input_spec.persistent or value is not None + ), "Expected state_dict value for persistent value" + node = placeholder_nodes[arg.name] + node_ir_type = self._cc.node_val_to_type(node, mutable=False) + parameter_bindings[node] = ( + value, + InputInfo( + prog, + input_spec, + node=node, + ir_type=node_ir_type, + mutable_producer_node_name=None, + ), + ) + elif input_spec.kind == InputKind.BUFFER and isinstance( + arg, TensorArgument + ): + # Remember buffer binding. + value = prog.state_dict.get(input_spec.target) + assert ( + not input_spec.persistent or value is not None + ), "Expected state_dict value for persistent value" + node = placeholder_nodes[arg.name] + mutable_producer_node_name = mutable_buffer_target_producers.get( + input_spec.target + ) + node_ir_type = self._cc.node_val_to_type( + node, mutable=bool(mutable_producer_node_name) + ) + buffer_bindings[node] = ( + value, + InputInfo( + prog, + input_spec, + node=node, + ir_type=node_ir_type, + mutable_producer_node_name=mutable_producer_node_name, + ), + ) + else: + raise NotImplementedError( + f"InputSpec not of a known kind: {input_spec}" + ) + + ftype = FunctionType.get(user_input_types, user_output_types, context=self._c) + + # Create the function. + with loc: + func_op = func_dialect.FuncOp(func_name, ftype, ip=self._m_ip) + entry_block = Block.create_at_start(func_op.body, ftype.inputs) + + node_importer = GraphNodeImporter( + self, + self._c, + self._cc, + entry_block, + ) + + # Bind constants to IR values. + for constant_node, constant_tensor in constant_tensors.items(): + node_importer.import_constant(loc, constant_node, constant_tensor) + + # Bind user inputs to IR values. + for user_input_node, block_arg_value in zip(user_inputs, entry_block.arguments): + if user_input_node.name in mutated_user_inputs: + # Materialize + node_importer.import_mutable_to_vtensor( + loc, + user_input_node, + block_arg_value, + mutated_user_inputs[user_input_node.name], + ) + else: + # Normal value tensor binding. + node_importer.bind_node_value(user_input_node, block_arg_value) + + # Lazy bind buffer and parameter inputs. + for node, (parameter_value, info) in parameter_bindings.items(): + node_importer.lazy_import_parameter(loc, node, parameter_value, info) + for node, (buffer_value, info) in buffer_bindings.items(): + node_importer.lazy_import_buffer(loc, node, buffer_value, info) + + # Import all nodes and return. + node_importer.import_nodes( + all_producer_nodes.values(), skip_placeholders_outputs=True + ) + node_importer.return_node_values(loc, user_outputs) + self.symbol_table.insert(func_op) + + def import_frozen_program( + self, prog: torch.export.ExportedProgram, func_name: str = "main" + ): + """Imports a consolidated torch.export.ExportedProgram instance. + + If using the new torch.export path (vs a lower level precursor), then this is + the recommended way to canonically use this importer. + + The ExportedProgram form differs from some of the earlier work primarily in + how it deals with references to external tensors from "outside". In this form, + all such references are checked to have originated from within the exported + scope or from an @assume_constant_result wrapped function. Then they are + transformed to graph inputs and stashed in one of two data structures on + the ExportedProgram: + inputs_to_buffers / buffers : For non-parameter buffers. + inputs_to_parameters / parameters : For parameter buffers. + The values of the mapping in inputs_to_{buffers|parameters} are in the + state_dict. This replaces get_attr nodes that would have classically been + present during lower level tracing. + Historically, torch-mlir has assumed that all such external accesses are + frozen, and this entry-point preserves this behavior, treating each distinct + torch.Tensor encountered in such a way as a `torch.vtensor.literal` (or + delegating to the literal_resolver_callback to make a policy decision). + + As we anticipate more nuanced treatment options in the future, we name this + method to indicate that it is producing "frozen" modules. Additional top-level + approaches to handling state can be introduced later as an addition. + + TODO: This mechanism should be eventually replaced by `import_program` with + hooks set on the subclass to freeze parameters and buffers. However, that is + waiting for the Torch 2.3 release cut. + """ + sig = prog.graph_signature + state_dict = prog.state_dict + arg_replacements: dict[str, Any] = {} + + # If there is no "constants" attribute, consult the "state_dict". Otherwise, only look + # at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969 + if hasattr(prog, "constants"): + constants = prog.constants + # Lift tensor constants. + for input_name, state_name in sig.inputs_to_lifted_tensor_constants.items(): + try: + state_value = constants[state_name] + except KeyError as e: + raise AssertionError( + "Could not find state mapping for tensor constants" + ) from e + arg_replacements[input_name] = state_value + else: + # Lift buffers. + for input_name, state_name in sig.inputs_to_buffers.items(): + try: + state_value = state_dict[state_name] + except KeyError as e: + raise AssertionError( + "Could not find state mapping for buffer" + ) from e + arg_replacements[input_name] = state_value + + # Lift parameters. + for input_name, state_name in sig.inputs_to_parameters.items(): + try: + state_value = state_dict[state_name] + except KeyError as e: + raise AssertionError( + "Could not find state mapping for parameter" + ) from e + arg_replacements[input_name] = state_value + + # Remove any lifted placeholders, replacing their uses with the state + # replacement value. + g = prog.graph + for node in g.nodes: + if node.op == "placeholder": + replacement = arg_replacements.get(node.name) + if replacement is None: + continue + node.replace_all_uses_with(replacement) + g.erase_node(node) + + self.import_stateless_graph(g, func_name) + + def import_graph_module(self, gm: GraphModule): + """Low-level import of a GraphModule assuming that it has been functionalized. + + TODO: This mechanism is deprecated by the `import_program` entry-point and + it should be removed when no longer required for backwards compatibility. + """ + self.import_stateless_graph(gm.graph) + + def import_stateless_graph(self, g: Graph, func_name: str = "main"): + """Low-level import of a functionalized, assumed stateless Graph as a func. + + TODO: This mechanism is deprecated by the `import_program` entry-point and + it should be removed when no longer required for backwards compatibility. + """ + ftype, loc = self._graph_to_function_meta(g) + # TODO: The FuncOp constructor requires a context-manager context. + # Fix upstream and then unnest. + # See: https://github.com/nod-ai/SHARK-Turbine/issues/138 + with loc: + func = func_dialect.FuncOp( + func_name, + ftype, + ip=self._m_ip, + ) + entry_block = Block.create_at_start(func.body, ftype.inputs) + node_importer = GraphNodeImporter( + self, + self._c, + self._cc, + entry_block, + ) + node_importer.import_nodes(g.nodes) + self.symbol_table.insert(func) + + def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]: + """Extracts function metadata from the Graph. + + Principally, this includes the FunctionType, but in the future, + it should also return other annotations (input strides, etc) that + affect compilation and should be included as arg attrs. + """ + input_types = [] + result_types = [] + loc = None + for node in g.nodes: + # Assume that the first node we can get a location for is about as + # good as it gets as an overall function location. + if loc is None: + loc = self._cc.get_node_location(node) + if node.op == "placeholder": + input_types.append(self._cc.node_val_to_type(node)) + elif node.op == "output": + # An output node's args[0] is the return value. This seems to + # always be "boxed" as a tuple, which we emit as multi-results. + for result_node in node.args[0]: + if result_node is None: + result_types.append( + IrType.parse("!torch.none", context=self._c) + ) + else: + result_types.append(self._cc.node_val_to_type(result_node)) + return ( + FunctionType.get(input_types, result_types, context=self._c), + loc if loc else Location.unknown(self._c), + ) + + +class ContextCache: + """Caches per-context lookups of various things that we ask for repeatedly.""" + + __slots__ = [ + "_c", + "_dtype_to_type", + "_tensor_metadata_cache", + "_py_attr_tracker", + # Types. + "torch_bool_type", + "torch_float_type", + "torch_int_type", + "torch_none_type", + "torch_str_type", + "torch_device_type", + ] + + def __init__( + self, context: Context, *, py_attr_tracker: Optional["RefTracker"] = None + ): + self._c = context + self._dtype_to_type: Dict[TorchDtype, IrType] = {} + self._tensor_metadata_cache: Dict[ + Tuple[torch.Size, torch.dtype, Optional[SparsityMeta], bool], IrType + ] = {} + self._py_attr_tracker = py_attr_tracker or RefTracker() + + # Common types. + with context: + self.torch_bool_type = IrType.parse("!torch.bool") + self.torch_float_type = IrType.parse("!torch.float") + self.torch_int_type = IrType.parse("!torch.int") + self.torch_none_type = IrType.parse("!torch.none") + self.torch_str_type = IrType.parse("!torch.str") + self.torch_device_type = IrType.parse("!torch.Device") + + def integer_attr(self, value: int, bits: int) -> Attribute: + c = self._c + return IntegerAttr.get(IntegerType.get_signless(bits, c), value) + + def format_asm_shape(self, shape: torch.Size) -> str: + """Strips symbolic elements from a torch.Size object and returns shape asm""" + return ",".join("?" if is_symbolic(d) else str(d) for d in list(shape)) + + def get_vtensor_type( + self, + shape: torch.Size, + dtype: torch.dtype, + *, + sparsity: Optional[SparsityMeta] = None, + mutable: bool = False, + ): + """Return IrType for !torch.vtensor with the given shape and dtype""" + stem = "torch.tensor" if mutable else "torch.vtensor" + shape_asm = self.format_asm_shape(shape) + mlir_dtype = str(self.dtype_to_type(dtype)) + if sparsity is not None: + encoding = sparsity_encoding(shape, sparsity) + assert encoding is not None + return IrType.parse( + f"!{stem}<[{shape_asm}],{str(mlir_dtype)},{encoding}>", + context=self._c, + ) + return IrType.parse( + f"!{stem}<[{shape_asm}],{str(mlir_dtype)}>", context=self._c + ) + + def node_val_to_type(self, node: torch_fx.Node, *, mutable: bool = False) -> IrType: + try: + tensor_meta = node.meta.get("tensor_meta") + val = node.meta.get("val") + sparsity = node.meta.get("sparsity", None) + if tensor_meta is not None: + assert isinstance(tensor_meta, TensorMetadata) + # Quantized tensor meta data is not preserved in our lowering, + # so throw error instead of silently doing wrong thing. + if tensor_meta.is_quantized: + raise NotImplementedError( + f"Quantized tensor meta data is not supported." + ) + else: + return self.tensor_metadata_to_type( + tensor_meta, sparsity=sparsity, mutable=mutable + ) + elif val is not None: + # some nodes with symbolic inputs pass a 'val' attribute rather than + # tensor_meta + if isinstance(val, TorchFakeTensor): + return self.get_vtensor_type( + val.size(), val.dtype, sparsity=sparsity, mutable=mutable + ) + + t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val)) + if t is not None: + return IrType.parse(t, self._c) + + raise NotImplementedError( + f"FIXME: Unsupported placeholder node (this often indicates that a necessary) " + f"fx preprocessing pass was not run): {node.meta}" + ) + except KeyError as e: + raise RuntimeError( + f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})" + ) + + def tensor_metadata_to_type( + self, + tm: TensorMetadata, + *, + sparsity: Optional[SparsityMeta] = None, + mutable: bool = False, + ) -> IrType: + tm_shape = tuple( + item.node if is_symbolic(item) else item for item in list(tm.shape) + ) + + key = (tm_shape, tm.dtype, sparsity, mutable) + t = self._tensor_metadata_cache.get(key) + if t is None: + t = self.get_vtensor_type( + tm.shape, tm.dtype, sparsity=sparsity, mutable=mutable + ) + self._tensor_metadata_cache[key] = t + return t + + def dtype_to_type(self, dtype: TorchDtype) -> IrType: + t = self._dtype_to_type.get(dtype) + if t is None: + try: + asm = TORCH_DTYPE_TO_MLIR_TYPE_ASM[dtype] + except IndexError: + raise ValueError(f"Unknown conversion from {dtype} to IREE type") + t = IrType.parse(asm, self._c) + self._dtype_to_type[dtype] = t + return t + + def tensor_to_vtensor_type(self, tensor: torch.Tensor) -> IrType: + dtype_asm = str(self.dtype_to_type(tensor.dtype)) + return IrType.parse(f"!torch.vtensor<{list(tensor.size())},{dtype_asm}>") + + def get_node_location(self, node: torch_fx.Node) -> Optional[Location]: + stack_trace = node.meta.get("stack_trace") + if stack_trace is None: + return None + # Ugh. + # TODO: Avoid needing to regex match this. + # https://github.com/pytorch/pytorch/issues/91000 + stack_trace = node.stack_trace + if stack_trace: + m = re.search(r"""File "([^"]+)", line ([0-9]+),""", stack_trace) + if m: + filename, line = m.group(1), int(m.group(2)) + return Location.file(filename, line, col=0, context=self._c) + return Location.unknown(context=self._c) + + +class GraphNodeImporter: + """Imports graph nodes into an MLIR function. + + The caller must have already created the function. + """ + + __slots__ = [ + "_b", + "_c", + "_cc", + "_on_node_produced", + "_v", + "_multi_result_nodes", + "fx_importer", + ] + + def __init__( + self, + fx_importer: FxImporter, + context: Context, + context_cache: ContextCache, + block: Block, + ): + self.fx_importer = fx_importer + self._c = context + self._cc = context_cache + self._b = block + # Map of (Node, result_index) to MLIR Value or a callback that lazily + # constructs and returns a value. + self._v: Dict[Union[Callable[[], Value], Tuple[torch_fx.Node, int]], Value] = {} + # Map of node name to hook that should be called when it is produced. + self._on_node_produced: dict[str, Callable[[Value], None]] = {} + # Statically multi-result nodes which we have de-tupled are noted here. + # They will have their getitem calls short-circuited. + self._multi_result_nodes: Set[torch_fx.Node] = set() + + def bind_node_value( + self, + node: Node, + value: Union[Value, Callable[[], Value]], + result_index: int = 0, + ): + """Binds a node to a value (and asserts if already bound). + + This is used by outside callers. Many internal callers poke directly + into the dict. + """ + key = (node, result_index) + assert key not in self._v, f"Node already has a value: {node}" + self._v[key] = value + + producer_callback = self._on_node_produced.get(node.name) + if producer_callback is not None: + producer_callback(value) + + def resolve_node_value(self, node: Node, result_index: int = 0) -> Value: + """Resolves a node to a value.""" + key = (node, result_index) + try: + binding = self._v[key] + except KeyError: + raise KeyError(f"FX Node {node} has not been bound to an MLIR value") + if isinstance(binding, Value): + return binding + + # It is a lazy callback. + value = binding() + self._v[key] = value + return value + + def import_mutable_to_vtensor( + self, loc: Location, node: Node, mutable_value: Value, producer_node_name: str + ) -> Value: + """Imports a node that is represented by a mutable IR value. + + This will generate and associate the following with the node: + %0 = torch.copy.to_vtensor {mutable_value} + + Then it will also add a trigger such that when `producer_node_name` is + produced, the following will be generated: + torch.overwrite.tensor.contents {producer}, {mutable_value} + """ + with loc, InsertionPoint(self._b): + immutable_type = self._cc.node_val_to_type(node) + copy_result = Operation.create( + "torch.copy.to_vtensor", + results=[immutable_type], + operands=[mutable_value], + ).result + self.bind_node_value(node, copy_result) + + # Add the producer trigger. + def on_produced(value: Value): + with loc, InsertionPoint(self._b): + Operation.create( + "torch.overwrite.tensor.contents", + results=[], + operands=[value, mutable_value], + ) + + self._on_node_produced[producer_node_name] = on_produced + return copy_result + + def import_constant(self, loc: Location, node: Node, constant: Any) -> Value: + with loc, InsertionPoint(self._b): + value = self._import_literal(constant) + self.bind_node_value(node, value) + return value + + def lazy_import_parameter( + self, loc, node: Node, parameter_value: Any, info: InputInfo + ): + def _on_access() -> Value: + with loc, InsertionPoint(self._b): + # TODO: Should go to a parameter binding hook. + return self._import_input(parameter_value, info) + + self.bind_node_value(node, _on_access) + + def lazy_import_buffer( + self, + loc, + node: Node, + buffer_value: Any, + info: InputInfo, + ): + def _on_access() -> Value: + with loc, InsertionPoint(self._b): + # TODO: Should go to a buffer binding hook. + return self._import_input(buffer_value, info) + + self.bind_node_value(node, _on_access) + + if info.mutable_producer_node_name is not None: + + def on_produced(value: Value): + mutable_buffer_value = self.resolve_node_value(node) + with loc, InsertionPoint(self._b): + Operation.create( + "torch.overwrite.tensor.contents", + results=[], + operands=[value, mutable_buffer_value], + ) + + self._on_node_produced[info.mutable_producer_node_name] = on_produced + + def return_node_values(self, loc, nodes: list[Node]): + with loc, InsertionPoint(self._b): + operands = [self.resolve_node_value(n) for n in nodes] + func_dialect.ReturnOp(operands, loc=loc) + + def import_nodes( + self, nodes: Sequence[Node], *, skip_placeholders_outputs: bool = False + ): + with InsertionPoint(self._b): + loc = Location.unknown() + num_placeholders = 0 + for node in nodes: + op = node.op + # Attempt to extract locations. Not everything has them, + # so we do our best. + new_loc = self._cc.get_node_location(node) + if new_loc is not None: + loc = new_loc + if op == "placeholder" and not skip_placeholders_outputs: + # Associate the placeholder node with corresponding block + # argument. + self.bind_node_value(node, self._b.arguments[num_placeholders]) + num_placeholders += 1 + elif op == "call_function": + target = node.target + if target == operator.getitem: + # Special case handling of getitem for when it is resolving + # against a function call that we know has returned multiple + # results. We short-circuit this case because we have modeled + # function calls to natively return multiple results vs tupling. + getitem_ref, getitem_index = node.args + if getitem_ref in self._multi_result_nodes: + try: + self.bind_node_value( + node, + self.resolve_node_value(getitem_ref, getitem_index), + ) + except IndexError: + raise RuntimeError( + f"getitem de-aliasing failed. This likely " + f"indicates a programmer error that usually " + f"would have happened at runtime. Please " + f"notify developers if this case happens " + f"(at {loc})." + ) + else: + raise NotImplementedError( + f"General getitem access to non-multi-result ops" + ) + elif target in SYMBOLIC_TORCH_OPS or ( + is_symbolic(node.meta.get("val")) + and is_builtin_function_or_method(target) + ): + self._import_symbolic_torch_op(loc, node, target) + elif isinstance(target, TorchOpOverload): + # Dispatch to an ATen op. + self._import_torch_op_overload(loc, node, target) + else: + raise NotImplementedError( + f"FIX ME: Unimplemented call_function: target={node.target}, {node.meta}" + ) + elif op == "output" and not skip_placeholders_outputs: + # args[0] is a singleton tuple that we flatten into multiple + # results. + operands = [self._import_argument(loc, arg) for arg in node.args[0]] + func_dialect.ReturnOp(operands, loc=loc) + + def _promote_symbolic_scalar_int_float(self, loc, graph, param): + temp_target = torch.ops.aten.Float.Scalar + temp_node = Node( + graph=graph, + name=f"{str(param)}_as_float", + op="call_function", + target=temp_target, + args=(param,), + kwargs={}, + return_type=float, + ) + temp_node.meta["val"] = torch.sym_float(param.meta["val"]) + self._import_torch_op_overload(loc, temp_node, temp_target) + return temp_node + + def _import_symbolic_torch_op( + self, + loc: Location, + node: torch_fx.Node, + target: Union[ + torch._ops.OpOverloadPacket, BuiltinMethodType, BuiltinFunctionType + ], + ): + # parse builtin operations like add, sub, mul, etc. because dynamo captures these + # operations on symbolic arguments as regular python expressions rather than as torch ops + if is_builtin_function_or_method(target): + arg_types = [ + (arg.meta["val"].node.pytype if isinstance(arg, Node) else type(arg)) + for arg in node.args + ] + is_int = [item == int for item in arg_types] + if all(is_int): + op_overload = "int" + elif any(is_int): + if target.__name__ in ("add", "lt", "ge", "ne", "gt"): + op_overload = "float_int" + # put float arg first, as expected in signature + if arg_types[1] == float: + node.args = (node.args[1], node.args[0]) + else: + # promote int argument to float - following torch-mlir convention + arg0, arg1 = node.args + if is_int[0]: + if isinstance(arg0, Node): + prom_arg = self._promote_symbolic_scalar_int_float( + loc, node.graph, arg0 + ) + new_args = (prom_arg, arg1) + else: + arg0 = float(arg0) + new_args = (arg0, arg1) + else: + if isinstance(arg1, Node): + prom_arg = self._promote_symbolic_scalar_int_float( + loc, node.graph, arg1 + ) + new_args = (arg0, prom_arg) + else: + arg1 = float(arg1) + new_args = (arg0, arg1) + + node.args = new_args + op_overload = "float" + else: + op_overload = "float" + + torch_op = PY_BUILTIN_TO_TORCH_OP.get(target.__name__) + assert ( + torch_op is not None + ), f"Unsupported builtin function for symbolic types: {target} with args {node.args}" + concrete_target = getattr(torch_op, op_overload) + else: + if _IS_TORCH_2_1_OR_EARLIER: + concrete_target = SYMBOLIC_OP_TO_TORCH_OP.get((target, len(node.args))) + else: + concrete_target = SYMBOLIC_OP_TO_TORCH_OP.get(target) + + assert ( + concrete_target is not None + ), f"Unable to parse symbolic operation: {target} with args {node.args}" + self._import_torch_op_overload(loc, node, concrete_target) + + def _import_torch_op_overload( + self, loc: Location, node: torch_fx.Node, target: TorchOpOverload + ): + # replace lift_fresh_copy with clone op + if target == torch.ops.aten.lift_fresh_copy.default: + node.target = target = torch.ops.aten.clone.default + node.args = (node.args[0], None) + elif target == torch.ops.aten.lift_fresh_copy.out: + node.target = target = torch.ops.aten.clone.out + node.args = (node.args[0], None, node.args[1]) + # TODO: generalize empty.memory_format in the future + # Currently, the aten.baddbmm.default op for Unet includes multiplying an + # empty.memory_format input with a constant, which creates NaN values + # because empty.memory_format contains uninitialized data. Converting + # aten.baddbmm.default -> aten.zeros.default fixes the correctness issue + elif target == torch.ops.aten.empty.memory_format: + if len(node.users) == 1: + for key_node in node.users: + if key_node.target == torch.ops.aten.baddbmm.default: + node.target = target = torch.ops.aten.zeros.default + + schema = target._schema + assert isinstance(schema, FunctionSchema) + + # Map to a `torch` dialect name. + namespace, sep, unqualified_name = schema.name.partition("::") + assert sep, f"Malformed Torch op name {schema.name}" + mlir_op_name = f"torch.{namespace}.{unqualified_name}" + if schema.overload_name != "": + mlir_op_name += f".{schema.overload_name}" + + # Intervening to use Scalar ops due to incorrect ops from AOT-autograd with scalar arguments. + if mlir_op_name in TENSOR_SCALAR_OP_CONVERTER and ( + isinstance(node.args[1], float) or isinstance(node.args[1], int) + ): + mlir_op_name = TENSOR_SCALAR_OP_CONVERTER[mlir_op_name] + # we are dynamically changing which op is emitted here due to an issue in + # torch dynamo where it emits the Tensor variant of ops even when processing + # scalar arguments, therefore we retrieve the schema as well so that we + # consume the correct typing information when subsequently importing the + # function arguments and result types + # i.e. the code below is basically doing `schema = torch.ops.aten.my_op.Scalar._schema` + op_attrs = mlir_op_name.split(".") + op_overload = getattr(torch, "ops") + for i in range(1, len(op_attrs)): + op_overload = getattr(op_overload, op_attrs[i]) + schema = op_overload._schema + + return_count = len(schema.returns) + if return_count == 1: + # Unary return directly maps a single meta["val"] and cannot be subscripted. + # if "tensor_meta" is None, this will throw unsupported placeholder node error + result_types = [self._cc.node_val_to_type(node)] + elif return_count == 0: + # Some torch ops do have 0 returns, and these are supported with ZeroResults + # op trait. Python bindings for IR creation allow us to pass empty result_types + # for such ops. Therefore, we pass an empty result types for these cases. + result_types = [] + else: + # Multi-return will unpack the meta["val"] and trigger our getitem subscripting + # short-circuit above. Note that if we ever choose to also fully reify Python + # level result tuples, we will need to create a tuple-boxed version of this and + # redirect to it for generic object access. + + result_types = [] + for v in node.meta["val"]: + result_types.append(self._cc.tensor_metadata_to_type(v)) + result_types = tuple(result_types) + + self._multi_result_nodes.add(node) + # Unroll operands from formal parameters, args and kwargs. + operands = [] + for i, parameter in enumerate(schema.arguments): + if parameter.kwarg_only and parameter.name in node.kwargs: + operands.append( + self._import_argument( + loc, node.kwargs[parameter.name], parameter.type + ) + ) + elif i < len(node.args): + operands.append( + self._import_argument(loc, node.args[i], parameter.type) + ) + else: + operands.append( + self._import_default_value( + loc, parameter.default_value, parameter.type + ) + ) + + # Support unregistered torch ops using torch.operator. + # torch.operator is used to represent ops from registry + # which haven't been generated by torch_ods_gen.py. + if not self._c.is_registered_operation(mlir_op_name): + operation = Operation.create( + "torch.operator", + attributes={"name": StringAttr.get(mlir_op_name)}, + results=result_types, + operands=operands, + loc=loc, + ) + else: + operation = Operation.create( + mlir_op_name, + results=result_types, + operands=operands, + loc=loc, + ) + + # Record value mapping. + for i, value in enumerate(operation.results): + self.bind_node_value(node, value, i) + + def _import_argument( + self, loc: Location, arg: NodeArgument, expected_jit_type=None + ) -> Value: + """Import an FX `Argument`, which must result to an MLIR `Value`.""" + if isinstance(arg, torch_fx.Node): + # If implementing boxed support for multi-result nodes, then + # this will need to do something more intelligent. + if arg in self._multi_result_nodes: + raise RuntimeError(f"Attempt to de-reference a multi-result node") + + # catch references to dynamically created constant attributes and make sure they have an origin in our module + if arg.op == "get_attr" and (arg.target, 0) not in self._v: + gm = arg.graph.owning_module + assert hasattr( + gm, arg.target + ), f"Attempting to retrieve attribute '{arg.target}' from module, but no such attribute exists" + obj = getattr(gm, arg.target) + with loc: + self.bind_node_value(arg, self._import_literal(obj)) + + return self.resolve_node_value(arg) + elif isinstance(arg, torch_fx.immutable_collections.immutable_list): + return self._import_list_argument(loc, arg, expected_jit_type) + elif isinstance(expected_jit_type, torch.TensorType) and not isinstance( + arg, torch.Tensor + ): + # promote scalars to tensor types as appropriate + return self._import_scalar_as_tensor(loc, arg) + else: + with loc: + return self._import_literal(arg) + + def _import_literal(self, py_value: Any) -> Value: + # Apply the conversion callback. + user_value = self.fx_importer._hooks.resolve_literal(self, py_value) + if user_value is not None: + assert isinstance(user_value, Value) + return user_value + + # Default conversion path. + converter = LITERAL_CONVERTER_MAP.lookup(type(py_value)) + if converter is None: + raise TypeError( + f"Unsupported argument -> literal conversion for {py_value.__class__}" + ) + return converter(py_value, self, self._cc) + + def _import_input(self, py_value: Any, info: InputInfo) -> Value: + # Try the hook. + user_value = self.fx_importer._hooks.resolve_input(self, py_value, info) + if user_value is not None: + assert isinstance(user_value, Value) + return user_value + + # Fall-back to treating as a literal if not mutating. + if info.mutable_producer_node_name is not None: + raise ValueError( + f"Cannot import {info.input_spec} as a literal because it is mutable" + ) + return self._import_literal(py_value) + + def _import_scalar_as_tensor(self, loc: Location, arg: NodeArgument) -> Value: + tensor_arg = torch.tensor(arg) + result_type = self._cc.get_vtensor_type(tensor_arg.size(), tensor_arg.dtype) + with loc: + constant_arg = LITERAL_CONVERTER_MAP.lookup(type(arg))(arg, self, self._cc) + + return Operation.create( + name="torch.prim.NumToTensor.Scalar", + results=[result_type], + operands=[constant_arg], + loc=loc, + ).result + + def _import_list_argument( + self, loc: Location, arg: NodeArgument, expected_jit_type + ) -> Value: + assert ( + isinstance(expected_jit_type, torch.ListType) + or ( + isinstance(expected_jit_type, torch.OptionalType) + and isinstance(expected_jit_type.getElementType(), torch.ListType) + ) + or isinstance(expected_jit_type, NoneType) + ), f"Unexpected jit type as list argument: {arg} of type {expected_jit_type}" + + # parse list type + if expected_jit_type is None: + element_type = type(arg[0]) + else: + element_jit_type = expected_jit_type.getElementType() + + # this branch is needed to handle Optional[List[]] types + if isinstance(element_jit_type, torch.ListType): + element_jit_type = element_jit_type.getElementType() + + # this handles getting the inner types for List[Optional[]] types + is_optional_type = isinstance(element_jit_type, torch.OptionalType) + if is_optional_type: + element_jit_type = element_jit_type.getElementType() + element_type = TORCH_TYPE_TO_PY_TYPE[type(element_jit_type)] + + # create list operands + list_operands = [] + + for operand in arg: + operand_type = type(operand) + if isinstance(operand, Node): + if operand in self._multi_result_nodes: + raise RuntimeError(f"Attempt to de-reference a multi-result node") + val = self.resolve_node_value(operand) + val_type = str(val.type) + assert ( + isinstance(element_type, str) and element_type in val_type + ) or SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get( + element_type + ) == val_type, f"Heterogeneous lists are not supported: expected {element_type}, got {val_type}" + else: + assert (is_optional_type and operand_type is NoneType) or ( + element_type == operand_type + ), f"Heterogeneous lists are not supported: expected {element_type}, got {operand_type}" + + operand_jit_type = ( + torch.NoneType if operand_type is NoneType else element_jit_type + ) + val = self._import_default_value(loc, operand, operand_jit_type) + + list_operands.append(val) + + # construct list op + if is_optional_type: + list_type = PY_TYPE_TO_TORCH_OPTIONAL_LIST_TYPE[element_type] + else: + list_type = PY_TYPE_TO_TORCH_LIST_TYPE[element_type] + + result_type = IrType.parse(list_type, context=self._c) + operation = Operation.create( + "torch.prim.ListConstruct", + results=[result_type], + operands=list_operands, + loc=loc, + ) + + return operation.result + + def _import_default_value(self, loc: Location, arg, expected_jit_type) -> Value: + """Imports a defaulted value for a known function schema.""" + if isinstance(arg, list): + return self._import_list_argument(loc, arg, expected_jit_type) + + # The LITERAL_CONVERTER_MAP maps each arg to its respective constant + # of the expected jit IR type (types like torch.dtype will form a chain of + # maps to get to constant of expected_jit_type). + cvt = LITERAL_CONVERTER_MAP.lookup(type(arg)) + if cvt is None: + raise RuntimeError(f"Unhandled default value ({arg.__class__}): {arg})") + with loc: + return cvt(arg, self, self._cc) + + +def _make_constant_op( + op_name: str, value_attr: Attribute, result_type: Optional[IrType] = None +) -> Operation: + return Operation.create( + op_name, + results=[result_type if result_type else value_attr.type], + attributes={"value": value_attr}, + ) + + +def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType: + try: + dtype = tensor.dtype + element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]() + tensor_type = RankedTensorType.get(tuple(tensor.size()), element_type) + return tensor_type + except KeyError: + raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type") + + +def _make_vtensor_literal_op( + tensor: torch.Tensor, vtensor_type: IrType, py_attr_tracker: "RefTracker" +) -> Operation: + mapping = py_attr_tracker.track(tensor) + if mapping.is_empty: + # check support for bfloat16 + assert not ( + tensor.dtype == torch.bfloat16 and ml_dtypes is None + ), f"torch.bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" + # Resolve the attribute. + npy_dtype = TORCH_DTYPE_TO_NPY_TYPE.get(tensor.dtype) + assert ( + npy_dtype is not None + ), f"Can not create literal tensor for unsupported datatype: {tensor.dtype}" + # We need a raw buffer of data in order to create an ElementsAttr for the invocation of torch.vtensor.literal, + # but torch.Tensor does not fulfill the python buffer/array interface hence we must convert to a numpy array to get + # a raw buffer of our data. We can't call torch.Tensor.numpy() directly because this internally forces a call to + # detach() which throws an error as we are operating in a FakeTensorMode, hence the simplest way to get this raw + # buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as + # desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above) + np_tensor = np.array(tensor.tolist()).astype(npy_dtype) + # One element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not + # support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling + # 0d tensors. + if np_tensor.size == 1: + try: + dtype = tensor.dtype + element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]() + except KeyError: + raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type") + elements_attr = DenseElementsAttr.get( + type=element_type, array=np_tensor, shape=np_tensor.shape + ) + else: + bytes_view = np_tensor.view(npy_dtype) + tensor_type = create_mlir_tensor_type(tensor) + shape_desc = "_".join([str(d) for d in tensor.shape]) + blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}" + elements_attr = DenseResourceElementsAttr.get_from_buffer( + bytes_view, + blob_name, + tensor_type, + ) + mapping.value = elements_attr + else: + elements_attr = mapping.value + return Operation.create( + name="torch.vtensor.literal", + results=[vtensor_type], + attributes={"value": elements_attr}, + ) + + +################################################################################ +# TypeSubclassMapping +################################################################################ + + +class TypeSubclassMap: + """Mapping of super-types to values. + + Maintains a cache of actual types seen and uses that instead of a linear + scan. + """ + + __slots__ = [ + "_cache", + "_mapping", + ] + + def __init__(self): + # The linear list of converters. + self._mapping: List[Tuple[type, Any]] = [] + # When there is a hit on the linear mapping, memoize it here. + self._cache: Dict[type, Any] = {} + + def map(self, t: type, value: Any): + self._mapping.append((t, value)) + self._cache[t] = value + + def lookup(self, t: type) -> Any: + try: + return self._cache[t] + except KeyError: + pass + for t_super, value in self._mapping: + if issubclass(t, t_super): + self._cache[t] = value + return value + else: + self._cache[t] = None + return None + + +############################################################################### +# Reference mapping +############################################################################### + + +# Opaque value to indicate something is empty. Used in cases where 'None' +# may have a different meaning. +class EmptyType: ... + + +Empty = EmptyType() + + +class RefMapping: + __slots__ = [ + "_referrent", + "value", + ] + + def __init__(self, referrent: Any): + if referrent is not Empty: + self._referrent = weakref.ref(referrent) + self.value = Empty + + @property + def is_empty(self): + return self.value is Empty + + def __repr__(self): + return ( + f" " + f"{self.value if self.value is not Empty else 'empty'}>" + ) + + +class RefTracker: + """Tracks live references from Python values to symbolic associations.""" + + def __init__(self): + self._refs: Dict[int, RefMapping] = {} + + def track(self, referrent: Any) -> RefMapping: + ref_id = id(referrent) + existing = self._refs.get(ref_id) + if existing: + return existing + info = RefMapping(referrent) + if referrent is not Empty: + weakref.finalize(referrent, self._ref_finalizer, ref_id) + self._refs[ref_id] = info + return info + + def _ref_finalizer(self, ref_id: int): + del self._refs[ref_id] + + +################################################################################ +# Mappings +################################################################################ + +LITERAL_CONVERTER_MAP = TypeSubclassMap() +LITERAL_CONVERTER_MAP.map( + NoneType, + lambda arg, gni, cc: Operation.create( + "torch.constant.none", results=[cc.torch_none_type] + ).result, +) +LITERAL_CONVERTER_MAP.map( + bool, + lambda arg, gni, cc: _make_constant_op( + "torch.constant.bool", cc.integer_attr(arg, 1), cc.torch_bool_type + ).result, +) +LITERAL_CONVERTER_MAP.map( + int, + lambda arg, gni, cc: _make_constant_op( + "torch.constant.int", cc.integer_attr(arg, 64), cc.torch_int_type + ).result, +) +LITERAL_CONVERTER_MAP.map( + float, + lambda arg, gni, cc: _make_constant_op( + "torch.constant.float", FloatAttr.get_f64(arg), cc.torch_float_type + ).result, +) +LITERAL_CONVERTER_MAP.map( + str, + lambda arg, gni, cc: _make_constant_op( + "torch.constant.str", StringAttr.get(arg), cc.torch_str_type + ).result, +) +LITERAL_CONVERTER_MAP.map( + torch.Tensor, + lambda arg, gni, cc: _make_vtensor_literal_op( + arg, cc.tensor_to_vtensor_type(arg), cc._py_attr_tracker + ).result, +) +LITERAL_CONVERTER_MAP.map( + torch.device, + lambda arg, gni, cc: _make_constant_op( + "torch.constant.device", StringAttr.get(str(arg)), cc.torch_device_type + ).result, +) +LITERAL_CONVERTER_MAP.map( + torch.dtype, + lambda arg, gni, cc: LITERAL_CONVERTER_MAP.lookup(int)( + TORCH_DTYPE_TO_INT[arg], gni, cc + ), +) +LITERAL_CONVERTER_MAP.map( + torch.layout, + lambda arg, gni, cc: LITERAL_CONVERTER_MAP.lookup(int)( + TORCH_LAYOUT_TO_INT[arg], gni, cc + ), +) +LITERAL_CONVERTER_MAP.map( + torch.memory_format, + lambda arg, gni, cc: LITERAL_CONVERTER_MAP.lookup(int)( + TORCH_MEMORY_FORMAT_TO_INT[arg], gni, cc + ), +) + +TORCH_TYPE_TO_PY_TYPE = { + torch.IntType: int, + torch.FloatType: float, + torch.StringType: str, + torch.BoolType: bool, + torch.TensorType: "vtensor", +} + +PY_TYPE_TO_TORCH_LIST_TYPE = { + int: "!torch.list", + float: "!torch.list", + str: "!torch.list", + bool: "!torch.list", + "tensor": "!torch.list", + "vtensor": "!torch.list", +} + +PY_TYPE_TO_TORCH_OPTIONAL_LIST_TYPE = { + int: "!torch.list>", + float: "!torch.list>", + str: "!torch.list>", + bool: "!torch.list>", + "tensor": "!torch.list>", + "vtensor": "!torch.list>", +} + +SCALAR_TYPE_TO_TORCH_MLIR_TYPE = { + torch.SymInt: "!torch.int", + torch.SymFloat: "!torch.float", + torch.SymBool: "!torch.bool", + int: "!torch.int", + float: "!torch.float", + str: "!torch.str", + bool: "!torch.bool", + NoneType: "!torch.none", +} + + +# AOT-autograd sometimes falsely emit tensor version op with scalar arguments. +# We may remove this dictionary, if we fix such behavior in the backend. +TENSOR_SCALAR_OP_CONVERTER = { + "torch.aten.mul.Tensor": "torch.aten.mul.Scalar", + "torch.aten.div.Tensor": "torch.aten.div.Scalar", + "torch.aten.add.Tensor": "torch.aten.add.Scalar", + "torch.aten.sub.Tensor": "torch.aten.sub.Scalar", + "torch.aten.floor_divide": "torch.aten.floor_divide.Scalar", +} diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py new file mode 100644 index 000000000000..289e5722efce --- /dev/null +++ b/python/torch_mlir/extras/onnx_importer.py @@ -0,0 +1,767 @@ +# Based on code Copyright (c) Advanced Micro Devices, Inc. +# +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +"""Imports ONNX graphs to `torch` dialect ops. + +See documentation: + https://github.com/llvm/torch-mlir/blob/main/docs/importers/onnx_importer.md + +This file is distributed/forked verbatim into various downstream projects, and +it must abide by several rules above and beyond the rest of the codebase: + - It must be standalone, only depending on: + - `onnx` + - `..ir` relative imports to the main IR directory + - `..dialects.func` relative import to the `func` dialect (TODO: + we are looking to eliminate this dep). + - Python standard library + - It does not directly use the ODS generated `torch` dialect Python + wrappers. This allows it to be used in contexts that only build a C++ + compiler with minimal IR Python bindings. + - It is intended as an enabler for full onnx compilation, only handling + the import from ONNX -> the `torch` dialect. Testing, full pipelines, + and utilities belong elsewhere. +""" + +try: + import onnx +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "The onnx package (`pip install onnx`) is required to use the onnx importer" + ) from e + +from typing import Optional + +from dataclasses import dataclass + +import numpy as np +import re + +from ..ir import ( + ArrayAttr, + Attribute, + Block, + Context, + DenseElementsAttr, + DenseResourceElementsAttr, + DictAttr, + FloatAttr, + BF16Type, + ComplexType, + F16Type, + F32Type, + F64Type, + Float8E4M3FNType, + Float8E5M2FNUZType, + Float8E5M2Type, + FunctionType, + InsertionPoint, + IntegerAttr, + IntegerType, + MLIRError, + RankedTensorType, + Location, + Module, + Operation, + StringAttr, + Type as IrType, + Value, +) + +from ..dialects import ( + func as func_dialect, +) + +@dataclass +class Config: + """Various configuration settings for the importer.""" + + # Ancient ONNX exporters would often add a model input for anything that + # might be mutable, providing an initializer for it as well. More modern + # tools tools realized this is a really bad idea for a lot of reasons. + # We choose to assume more recent norms, even if encountering older + # models. Setting this to False probably won't do what you want but + # should produce interesting errors to waste your time deciphering. + # We mainly use it as a way to document in the code that we are + # making an assumption. + elide_initialized_inputs: bool = True + + +class ModelInfo: + """Top-level accounting and accessors for an ONNX model.""" + + def __init__(self, model_proto: onnx.ModelProto, *, config: Config = Config()): + self.config = config + self.model_proto = model_proto + assert model_proto.graph, "Model must contain a main Graph" + self.main_graph = GraphInfo(self, model_proto.graph) + + def create_module(self, context: Optional[Context] = None) -> Module: + if not context: + context = Context() + module = Module.create(Location.unknown(context)) + # TODO: Populate module level metadata from the ModelProto + return module + + +class GraphInfo: + """Information about a Graph within a model.""" + + def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto): + self.model_info = model_info + self.graph_proto = graph_proto + self.initializer_map: dict[str, onnx.TensorProto] = { + n.name: n for n in graph_proto.initializer + } + self.value_info_map: dict[str, onnx.ValueInfoProto] = { + n.name: n for n in graph_proto.value_info + } + self.declared_input_map: dict[str, onnx.ValueInfoProto] = { + n.name: n for n in graph_proto.input + } + self.output_map: dict[str, onnx.ValueInfoProto] = { + n.name: n for n in graph_proto.output + } + + # Generate the effective input map, which for old models can be a + # subset of the input map. + if model_info and model_info.config.elide_initialized_inputs: + self.input_map = { + k: v + for k, v in self.declared_input_map.items() + if k not in self.initializer_map + } + else: + self.input_map = self.declared_input_map + illegal_input_keys = self.input_map.keys() - ( + self.input_map.keys() - self.initializer_map.keys() + ) + assert self.input_map.keys().isdisjoint(self.initializer_map.keys()), ( + f"When not in elide_initialized_inputs=True, we expect inputs to not " + f"have an initial value (got {illegal_input_keys})." + ) + + def find_type_proto_for_name(self, name: str) -> onnx.TypeProto: + # Node outputs don't typically have type information, but shape inference + # will associate them in the value_info. If not there, it may be a + # graph output, which must have type information. + value_info = self.value_info_map.get(name) or self.output_map.get(name) + if value_info is not None: + return value_info.type + # No type information is associated, this can occur when the value is unused: + return "" + + +class OnnxImportError(Exception): + ... + + +class NodeImporter: + """Imports graph nodes into MLIR. + + Typically, the top level graph will be imported into a func whereas dependent + graphs may just be imported with references to pre-existing values. + + Note that ONNX requires that graphs be sorted topologically and free of cycles, + so we don't take any special steps to order them for dominance. + """ + + __slots__ = [ + "_c", + "_cc", + "_gi", + "_p", + "_b", + "_nv_map", + ] + + def __init__( + self, + graph_info: GraphInfo, + *, + parent_op: Operation, + block: Block, + context_cache: "ContextCache", + ): + self._c = parent_op.context + self._cc = context_cache + self._gi = graph_info + self._p = parent_op + self._b = block + self._nv_map: dict[str, Value] = {} + + @classmethod + def define_function( + cls, graph_info: GraphInfo, module_op: Operation + ) -> "NodeImporter": + cc = ContextCache(module_op.context) + with module_op.context, Location.name(f"graph:{graph_info.graph_proto.name}"): + body = module_op.regions[0].blocks[0] + func_name = graph_info.graph_proto.name + input_types = [ + cc.type_proto_to_type(inp.type) for inp in graph_info.input_map.values() + ] + output_types = [ + cc.type_proto_to_type(out.type) + for out in graph_info.output_map.values() + ] + ftype = FunctionType.get(input_types, output_types) + func_op = func_dialect.FuncOp(func_name, ftype, ip=InsertionPoint(body)) + block = func_op.add_entry_block( + [Location.name(k) for k in graph_info.input_map.keys()] + ) + imp = NodeImporter(graph_info, parent_op=func_op, block=block, context_cache=cc) + for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments): + imp._nv_map[node_name] = input_value + imp._populate_graph_attrs(func_op) + return imp + + def _populate_graph_attrs(self, container_op: Operation): + """Populates graph level meta attributes on the given container op.""" + m = self._gi.model_info.model_proto + with container_op.context: + i64_type = IntegerType.get_signed(64) + default_opset_version = 0 + opset_versions: dict[str, IntegerAttr] = {} + for opset_import in m.opset_import: + if opset_import.domain: + opset_versions[opset_import.domain] = IntegerAttr.get( + i64_type, opset_import.version + ) + else: + default_opset_version = opset_import.version + if default_opset_version: + container_op.attributes[ + "torch.onnx_meta.opset_version" + ] = IntegerAttr.get(i64_type, default_opset_version) + if opset_versions: + container_op.attributes[ + "torch.onnx_meta.opset_versions" + ] = DictAttr.get(opset_versions) + container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get( + IntegerType.get_signed(64), m.ir_version + ) + container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get( + m.producer_name + ) + container_op.attributes[ + "torch.onnx_meta.producer_version" + ] = StringAttr.get(m.producer_version) + + def import_all(self, func=True): + """Imports all nodes topologically.""" + # TODO: Consider pulling in initializers on demand since there can be so + # much unused crap. + for init in self._gi.initializer_map.values(): + self.import_initializer(init) + + self.get_none() + for node in self._gi.graph_proto.node: + self.import_node(node) + + outputs = [] + for output_name in self._gi.output_map.keys(): + try: + outputs.append(self._nv_map[output_name]) + except KeyError: + raise OnnxImportError( + f"Non topologically produced ONNX graph output '{output_name}'" + ) + with InsertionPoint(self._b), Location.unknown(): + if func: + func_dialect.ReturnOp(outputs) + else: + Operation.create( + name="torch.operator_terminator", + operands=outputs) + + def get_none(self): + if '' in self._nv_map: + return self._nv_map[''] + + with InsertionPoint(self._b), Location.name("onnx_importer.none"): + nne = Operation.create( + name="torch.constant.none", + results=[self._cc.get_none_type()], + operands=[], + attributes={}, + ).results[0] + self._nv_map[''] = nne + return nne + + def import_node(self, node: onnx.NodeProto): + with InsertionPoint(self._b), Location.name(node.name): + op_type = node.op_type + # Handle special op types that materialize to non-op IR constructs. + # Handlers return True if the op was handled, else this function + # should process it as a general node. + special_key = f"_handle_node_{op_type}" + if hasattr(self, special_key): + was_handled = getattr(self, special_key)(node) + if was_handled: + return + # General node import. + input_values = [] + for input_name in node.input: + try: + input_values.append(self._nv_map[input_name]) + except KeyError: + raise OnnxImportError( + f"Non topologically produced ONNX node input '{input_name}': {node}" + ) + + output_names = list(node.output) + output_types = [ + self._cc.type_proto_to_type(self._gi.find_type_proto_for_name(n)) + for n in output_names + ] + + attrs = self.import_attributes(node.attribute) + attrs["name"] = StringAttr.get(f"onnx.{op_type}") + regions = self.count_regions(node.attribute) + + custom_op = Operation.create( + name="torch.operator", + results=output_types, + operands=input_values, + attributes=attrs, + regions=regions + ) + + self.import_regions(node.attribute, custom_op) + for output_name, output_value in zip(output_names, custom_op.results): + self._nv_map[output_name] = output_value + + def import_attributes(self, onnx_attrs: list[onnx.AttributeProto]): + attrs = {} + for onnx_attr in onnx_attrs: + attr_type = onnx_attr.type + if attr_type not in ATTRIBUTE_TYPE_HANDLERS: + raise OnnxImportError( + f"Unhandled ONNX attribute type code {attr_type}: {onnx_attr}" + ) + handler = ATTRIBUTE_TYPE_HANDLERS[attr_type] + if handler is None: + # Active skip. + continue + elif handler is False: + # Active error. + raise OnnxImportError( + f"ONNX importer does not support generic node attribute type {attr_type}. " + f"This likely means that this is a special node which requires specific " + f"handling in the importer: {onnx_attr}" + ) + result = handler(onnx_attr, self._cc) + attrs[f"torch.onnx.{onnx_attr.name}"] = result + return attrs + + def count_regions(self, onnx_attrs: list[onnx.AttributeProto]): + count = 0 + for onnx_attr in onnx_attrs: + if onnx_attr.type == onnx.AttributeProto.AttributeType.GRAPH: + count += 1 + return count + + def import_regions(self, onnx_attrs: list[onnx.AttributeProto], op): + attr_map = {} + for onnx_attr in onnx_attrs: + attr_type = onnx_attr.type + if attr_type != onnx.AttributeProto.AttributeType.GRAPH: + continue + attr_map[onnx_attr.name] = onnx_attr + + for name, region in zip(sorted(attr_map.keys()), op.regions): + attr = attr_map[name] + block_types = [self._cc.type_proto_to_type(input.type) for input in attr.g.input] + block_names = [input.name for input in attr.g.input] + region.blocks.append(*block_types, arg_locs=[op.location] * len(block_types)) + block = region.blocks[0] + graph_info = GraphInfo(None, attr.g) + imp = NodeImporter(graph_info, parent_op=op, block=block, context_cache=self._cc) + + for node_name, input_value in zip(block_names, block.arguments): + imp._nv_map[node_name] = input_value + for k in self._nv_map: + imp._nv_map[k] = self._nv_map[k] + + imp.import_all(False) + + def import_initializer(self, initializer: onnx.TensorProto, extern_name: str = None) -> Value: + # If an explicitly specified name is given, use that; otherwise, pick + # up the name from the tensor proto itself + iname = extern_name if extern_name else initializer.name + with InsertionPoint(self._b), Location.name(iname): + value_attr = self._cc.tensor_proto_to_attr(initializer) + vtensor_type = self._cc.tensor_proto_to_type(initializer) + attrs = { + "name": StringAttr.get(f"onnx.Constant"), + "torch.onnx.value": value_attr, + } + literal_op = Operation.create( + name="torch.operator", + results=[vtensor_type], + attributes=attrs, + ) + self._nv_map[iname] = literal_op.result + return literal_op.result + + def _get_immediate_tensor(self, name: str) -> np.array: + try: + initializer = self._gi.initializer_map[name] + except KeyError: + raise OnnxImportError( + f"An immediate value for '{name}' was required but it is dynamically produced." + ) + try: + dtype = ELEM_TYPE_TO_NUMPY_DTYPE[initializer.data_type] + except KeyError: + raise OnnxImportError( + f"Unknown ONNX tensor element type to numpy dtype mapping: {initializer.data_type}" + ) + raw_data = initializer.raw_data + if raw_data: + return np.frombuffer(raw_data, dtype=dtype).reshape(tuple(initializer.dims)) + else: + raise OnnxImportError( + f"Unhandled ONNX TensorProto immediate data: {initializer}" + ) + + def _handle_node_Constant(self, node: onnx.NodeProto) -> bool: + # Special case only for constants specified by value attribute (for now) + value_proto = _get_attr(node, "value", False) + if not value_proto: + return False + + # Produce an initializer for the constant, so that it can be used in + # combination with other ops, such as ConstantOfShape, requiring + # a constant input + assert value_proto.type == onnx.AttributeProto.AttributeType.TENSOR + assert len(node.output) == 1 + const_name = node.output[0] + self.import_initializer(value_proto.t, const_name) + self._gi.initializer_map[const_name] = value_proto.t + return True + +class ContextCache: + """Caches per-context lookups of various things.""" + + __slots__ = [ + "_c", + "_elem_type_map", + "_list_type_map", + "_optional_type_map", + "_vtensor_type_map", + ] + + def __init__(self, context: Context): + self._c = context + self._elem_type_map: dict[int, IrType] = {} + self._list_type_map:dict[str, IrType] = {} + self._optional_type_map:dict[str, IrType] = {} + self._vtensor_type_map: dict[tuple[tuple[Optional[int]], IrType], IrType] = {} + + def tensor_element_type(self, elem_type: int) -> IrType: + t = self._elem_type_map.get(elem_type) + if t is None: + try: + with self._c: + t = ELEM_TYPE_TO_IR_TYPE_CB[elem_type]() + except KeyError: + raise OnnxImportError(f"Unknown ONNX tensor element type: {elem_type}") + self._elem_type_map[elem_type] = t + return t + + def get_none_type(self): + return IrType.parse("!torch.none", context=self._c) + + def get_list_type(self, element_type: IrType) -> IrType: + key = str(element_type) + t = self._list_type_map.get(key) + if t is None: + asm = f"!torch.list<{str(element_type)}>" + try: + t = IrType.parse(asm, context=self._c) + except MLIRError as e: + raise OnnxImportError( + f"Unparseable torch type (MLIR asm format bug?): {asm}" + ) from e + self._list_type_map[key] = t + return t + + + def get_optional_type(self, element_type: IrType) -> IrType: + key = str(element_type) + t = self._optional_type_map.get(key) + if t is None: + asm = f"!torch.optional<{str(element_type)}>" + try: + t = IrType.parse(asm, context=self._c) + except MLIRError as e: + raise OnnxImportError( + f"Unparseable torch type (MLIR asm format bug?): {asm}" + ) from e + self._optional_type_map[key] = t + return t + + + def get_list_element_type(self, tp: onnx.TypeProto) -> IrType: + tt = tp.tensor_type + if tt.elem_type: + element_type = self.tensor_element_type(tt.elem_type) + dims = tuple( + (d.dim_value if not d.dim_param else None) for d in tt.shape.dim + ) + shape_asm = ",".join("?" if d is None else str(d) for d in dims) + return f"vtensor<[{shape_asm}],{element_type}>" + + raise OnnxImportError( + f"Unsupport list element type") + + def get_optional_element_type(self, tp: onnx.TypeProto) -> IrType: + st = tp.sequence_type + tt = tp.tensor_type + if tt.elem_type: + element_type = self.tensor_element_type(tt.elem_type) + dims = tuple( + (d.dim_value if not d.dim_param else None) for d in tt.shape.dim + ) + shape_asm = ",".join("?" if d is None else str(d) for d in dims) + return f"vtensor<[{shape_asm}],{element_type}>" + + if st.elem_type: + element_type = self.get_list_element_type(st.elem_type) + return f"list<{element_type}>" + + raise OnnxImportError( + f"Unsupport optional element type") + + def get_vtensor_type( + self, dims: tuple[Optional[int]], element_type: IrType + ) -> IrType: + key = (dims, element_type) + t = self._vtensor_type_map.get(key) + if t is None: + shape_asm = ",".join("?" if d is None else str(d) for d in dims) + asm = f"!torch.vtensor<[{shape_asm}],{str(element_type)}>" + try: + t = IrType.parse(asm, context=self._c) + except MLIRError as e: + raise OnnxImportError( + f"Unparseable torch type (MLIR asm format bug?): {asm}" + ) from e + self._vtensor_type_map[key] = t + return t + + def tensor_proto_to_type(self, tp: onnx.TensorProto) -> IrType: + element_type = self.tensor_element_type(tp.data_type) + return self.get_vtensor_type(tuple(tp.dims), element_type) + + def tensor_proto_to_builtin_type(self, tp: onnx.TensorProto) -> IrType: + element_type = self.tensor_element_type(tp.data_type) + # TODO: Fixme upstream: RankedTensorType.get should not require a location. + with Location.unknown(): + try: + return RankedTensorType.get(tuple(tp.dims), element_type) + except TypeError as e: + raise OnnxImportError( + f"Unsupported builtin tensor type" + ) from e + + + def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType: + if tp == "": + return self.get_none_type() + + tt = tp.tensor_type + if tt.elem_type: + if not tt.shape: + raise OnnxImportError( + f"Unsupported Tensor type without shape (run shape inference?): {tp}" + ) + element_type = self.tensor_element_type(tt.elem_type) + dims = tuple( + (d.dim_value if not d.dim_param else None) for d in tt.shape.dim + ) + return self.get_vtensor_type(dims, element_type) + + st = tp.sequence_type + if len(str(st.elem_type)) > 0: + element_type = self.get_list_element_type(st.elem_type) + return self.get_list_type(element_type) + + ot = tp.optional_type + if len(str(ot.elem_type)) > 0: + element_type = self.get_optional_element_type(ot.elem_type) + return self.get_optional_type(element_type) + + # TODO: Others if ever needed. Or we consider ourselves DNN-only. + # See TypeProto: sequence_type, map_type, optional_type, sparse_tensor_type. + raise OnnxImportError(f"Unsupported ONNX TypeProto: {tp}") + + def _sanitize_name(self, name): + if not name.isidentifier(): + name = "_" + name + return re.sub("[:/]", "_", name) + + def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute: + tensor_type = self.tensor_proto_to_builtin_type(tp) + if tp.HasField("raw_data"): + # Conveniently, DenseResourceElementsAttr shares the raw data + # format. We just give it maximum numeric alignment. + resource = DenseResourceElementsAttr.get_from_buffer( + tp.raw_data, self._sanitize_name(tp.name), tensor_type, alignment=8 + ) + return resource + else: + # We have to do a data type specific instantiation from proto fields. + # Since this is typically used for small tensor constants, we instantiate + # as a DenseElementsAttr. + handler = ELEM_TYPE_INLINE_TENSOR_PROTO_CB.get(tp.data_type) + if handler is None: + raise OnnxImportError(f"Unhandled ONNX TensorProto data: {tp}") + return handler(tp) + + +ELEM_TYPE_TO_IR_TYPE_CB = { + onnx.TensorProto.DataType.FLOAT: lambda: F32Type.get(), + onnx.TensorProto.DataType.UINT8: lambda: IntegerType.get_unsigned(8), + onnx.TensorProto.DataType.INT8: lambda: IntegerType.get_signed(8), + onnx.TensorProto.DataType.UINT16: lambda: IntegerType.get_unsigned(16), + onnx.TensorProto.DataType.INT16: lambda: IntegerType.get_signed(16), + onnx.TensorProto.DataType.INT32: lambda: IntegerType.get_signed(32), + onnx.TensorProto.DataType.INT64: lambda: IntegerType.get_signed(64), + onnx.TensorProto.DataType.BOOL: lambda: IntegerType.get_signless(1), + onnx.TensorProto.DataType.FLOAT16: lambda: F16Type.get(), + onnx.TensorProto.DataType.DOUBLE: lambda: F64Type.get(), + onnx.TensorProto.DataType.UINT32: lambda: IntegerType.get_unsigned(32), + onnx.TensorProto.DataType.UINT64: lambda: IntegerType.get_unsigned(64), + onnx.TensorProto.DataType.COMPLEX64: lambda: ComplexType.get(F32Type.get()), + onnx.TensorProto.DataType.COMPLEX128: lambda: ComplexType.get(F64Type.get()), + onnx.TensorProto.DataType.BFLOAT16: lambda: BF16Type.get(), + onnx.TensorProto.DataType.FLOAT8E4M3FN: lambda: Float8E4M3FNType.get(), + onnx.TensorProto.DataType.FLOAT8E4M3FNUZ: lambda: Float8E5M2FNUZType.get(), + onnx.TensorProto.DataType.FLOAT8E5M2: lambda: Float8E5M2Type.get(), + onnx.TensorProto.DataType.FLOAT8E5M2FNUZ: lambda: Float8E5M2FNUZType.get(), + onnx.TensorProto.DataType.STRING: lambda: "!torch.str", + # Ommitted: STRING, +} + +ELEM_TYPE_SPLAT_TENSOR_PROTO_CB = { + onnx.TensorProto.DataType.FLOAT: lambda tp, shape: DenseElementsAttr.get_splat( + RankedTensorType.get(shape, F32Type.get()), FloatAttr.get_f32(tp.float_data[0]) + ), + onnx.TensorProto.DataType.INT64: lambda tp, shape: DenseElementsAttr.get_splat( + RankedTensorType.get(shape, IntegerType.get_signed(64)), IntegerAttr.get( + IntegerType.get_signed(64), int.from_bytes(tp.raw_data, "little", + signed=True) if tp.HasField("raw_data") else tp.int64_data[0]) + ), + # TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB +} + +# Mapping of TensorProto.DataType to lambda TensorProto, returning a DenseElementsAttr +# of the builtin tensor type for cases where the tensor data is inlined as typed +# values instead of raw_data. +ELEM_TYPE_INLINE_TENSOR_PROTO_CB = { + onnx.TensorProto.DataType.FLOAT: lambda tp: DenseElementsAttr.get( + np.asarray(tp.float_data, dtype=np.float32).reshape(tp.dims), signless=False + ), + onnx.TensorProto.DataType.BOOL: lambda tp: DenseElementsAttr.get( + np.packbits(np.asarray(tp.int32_data, dtype=np.bool_).reshape(tp.dims), + axis=None, bitorder="little"), signless=False + ), + onnx.TensorProto.DataType.INT8: lambda tp: DenseElementsAttr.get( + np.asarray(tp.int32_data, dtype=np.int8).reshape(tp.dims), signless=False + ), + onnx.TensorProto.DataType.INT16: lambda tp: DenseElementsAttr.get( + np.asarray(tp.int32_data, dtype=np.int16).reshape(tp.dims), signless=False + ), + onnx.TensorProto.DataType.INT32: lambda tp: DenseElementsAttr.get( + np.asarray(tp.int32_data, dtype=np.int32).reshape(tp.dims), signless=False + ), + onnx.TensorProto.DataType.INT64: lambda tp: DenseElementsAttr.get( + np.asarray(tp.int64_data, dtype=np.int64).reshape(tp.dims), signless=False + ), + onnx.TensorProto.DataType.DOUBLE: lambda tp: DenseElementsAttr.get( + np.asarray(tp.double_data, dtype=np.float64).reshape(tp.dims) + ), + onnx.TensorProto.DataType.UINT32: lambda tp: DenseElementsAttr.get( + # Special case. See proto + np.asarray(tp.uint64_data, dtype=np.uint32).reshape(tp.dims), + signless=False, + ), + onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get( + np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False + ) + # Intentionally unsupported: STRING +} + +ELEM_TYPE_TO_NUMPY_DTYPE = { + onnx.TensorProto.DataType.FLOAT: np.float32, + onnx.TensorProto.DataType.UINT8: np.uint8, + onnx.TensorProto.DataType.INT8: np.int8, + onnx.TensorProto.DataType.UINT16: np.uint16, + onnx.TensorProto.DataType.INT16: np.int16, + onnx.TensorProto.DataType.INT32: np.int32, + onnx.TensorProto.DataType.INT64: np.int64, + onnx.TensorProto.DataType.BOOL: np.bool_, + onnx.TensorProto.DataType.FLOAT16: np.float16, + onnx.TensorProto.DataType.DOUBLE: np.float64, + onnx.TensorProto.DataType.UINT32: np.uint32, + onnx.TensorProto.DataType.UINT64: np.uint64, + onnx.TensorProto.DataType.COMPLEX64: np.complex64, + onnx.TensorProto.DataType.COMPLEX128: np.complex128, + # onnx.TensorProto.DataType.BFLOAT16: + # onnx.TensorProto.DataType.FLOAT8E4M3FN: + # onnx.TensorProto.DataType.FLOAT8E4M3FNUZ: + # onnx.TensorProto.DataType.FLOAT8E5M2: + # onnx.TensorProto.DataType.FLOAT8E5M2FNUZ: + # Ommitted: STRING, +} + +# Mapping of AttributeType code to one of: +# None: Ignore attribute and do not output to MLIR +# False: Error if an attribute of this type is present +# lambda a:AttributeProto, cc: ContextCache that returns an MLIR Attribute +ATTRIBUTE_TYPE_HANDLERS = { + onnx.AttributeProto.AttributeType.UNDEFINED: False, + onnx.AttributeProto.AttributeType.FLOAT: lambda a, cc: FloatAttr.get( + F32Type.get(), a.f + ), + onnx.AttributeProto.AttributeType.INT: lambda a, cc: IntegerAttr.get( + IntegerType.get_signed(64), a.i + ), + onnx.AttributeProto.AttributeType.STRING: lambda a, cc: StringAttr.get(a.s), + onnx.AttributeProto.AttributeType.TENSOR: lambda a, cc: cc.tensor_proto_to_attr( + a.t + ), + onnx.AttributeProto.AttributeType.GRAPH: None, + onnx.AttributeProto.AttributeType.SPARSE_TENSOR: False, + onnx.AttributeProto.AttributeType.TYPE_PROTO: False, + onnx.AttributeProto.AttributeType.FLOATS: lambda a, cc: ArrayAttr.get( + [FloatAttr.get(F32Type.get(), f) for f in a.floats] + ), + onnx.AttributeProto.AttributeType.INTS: lambda a, cc: ArrayAttr.get( + [IntegerAttr.get(IntegerType.get_signed(64), i) for i in a.ints] + ), + onnx.AttributeProto.AttributeType.STRINGS: lambda a, cc: ArrayAttr.get( + [StringAttr.get(s) for s in a.strings] + ), + onnx.AttributeProto.AttributeType.TENSORS: lambda a, cc: ArrayAttr.get( + [cc.tensor_proto_to_attr(t) for t in a.tensors] + ), + onnx.AttributeProto.AttributeType.GRAPHS: False, + onnx.AttributeProto.AttributeType.SPARSE_TENSORS: False, + onnx.AttributeProto.AttributeType.TYPE_PROTOS: False, +} + + +def _get_attr(node: onnx.NodeProto, attr_name: str, is_required: bool = True) -> onnx.AttributeProto: + for attr in node.attribute: + if attr.name == attr_name: + return attr + if is_required: + raise OnnxImportError(f"Required attribute {attr_name} not found in {node}") + return None diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py new file mode 100644 index 000000000000..3622efafd9d2 --- /dev/null +++ b/python/torch_mlir/fx.py @@ -0,0 +1,43 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +from typing import Optional + +import warnings + +import torch +import torch.export +import torch.nn as nn + +from torch_mlir.extras.fx_importer import FxImporter, FxImporterHooks +from torch_mlir import ir +from torch_mlir.dialects import torch as torch_d +from torch_mlir.extras.fx_decomp_util import get_decomposition_table + +def export_and_import( + f, + *args, + fx_importer: Optional[FxImporter] = None, + experimental_support_mutation: bool = False, + hooks: Optional[FxImporterHooks] = None, + func_name: str = "main", + **kwargs, +): + context = ir.Context() + torch_d.register_dialect(context) + + if fx_importer is None: + fx_importer = FxImporter(context=context, hooks=hooks) + prog = torch.export.export(f, args, kwargs) + decomp_table = get_decomposition_table() + prog = prog.run_decompositions(decomp_table) + if experimental_support_mutation: + if torch.__version__ < "2.3.0.dev20240207": + warnings.warn("Mutable program import only supported on PyTorch 2.3+") + fx_importer.import_program(prog, func_name=func_name) + else: + fx_importer.import_frozen_program(prog, func_name=func_name) + + return fx_importer.module_op diff --git a/python/torch_mlir/tools/import_onnx/__main__.py b/python/torch_mlir/tools/import_onnx/__main__.py new file mode 100644 index 000000000000..547fe5339dad --- /dev/null +++ b/python/torch_mlir/tools/import_onnx/__main__.py @@ -0,0 +1,171 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +"""Console tool for converting an ONNX proto to torch IR. + +Typically, when installed from a wheel, this can be invoked as: + + torch-mlir-import-onnx some.pb + +Or from Python: + + python -m torch_mlir.tools.import_onnx ... +""" +import argparse +import os +from pathlib import Path +import shutil +import sys + +import onnx + +from ...extras import onnx_importer + +from ...dialects import torch as torch_d +from ...ir import ( + Context, +) + + +def main(args: argparse.Namespace): + model_proto = load_onnx_model(args) + context = Context() + torch_d.register_dialect(context) + model_info = onnx_importer.ModelInfo(model_proto) + m = model_info.create_module(context=context).operation + imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m) + imp.import_all() + if not args.no_verify: + m.verify() + + # TODO: This isn't very efficient output. If these files ever + # get large, enable bytecode and direct binary emission to save + # some copies. + if args.output_file and args.output_file != "-": + with open(args.output_file, "wt") as f: + print(m.get_asm(assume_verified=not args.no_verify), file=f) + else: + print(m.get_asm(assume_verified=not args.no_verify)) + + +def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto: + # Do shape inference two ways. First, attempt in-memory to avoid redundant + # loading and the need for writing a temporary file somewhere. If that + # fails, typically because of the 2 GB protobuf size limit, try again via + # files. See + # https://github.com/onnx/onnx/blob/main/docs/PythonAPIOverview.md#shape-inference-a-large-onnx-model-2gb + # for details about the file-based technique. + + # Make a temp dir for all the temp files we'll be generating as a side + # effect of infering shapes. For now, the only file is a new .onnx holding + # the revised model with shapes. + # + # TODO: If the program temp_dir is None, we should be using an ephemeral + # temp directory instead of a hard-coded path in order to avoid data races + # by default. + input_dir = os.path.dirname(os.path.abspath(args.input_file)) + temp_dir = ( + Path(input_dir if args.temp_dir is None else args.temp_dir) + / "onnx-importer-temp" + ) + shutil.rmtree(temp_dir, ignore_errors=True) + temp_dir.mkdir(exist_ok=True) + + # Load the model, with possible external data coming from the default + # location, or the location specified on the conmand line. + if args.data_dir is None: + raw_model = onnx.load(args.input_file) + else: + raw_model = onnx.load(args.input_file, load_external_data=False) + onnx.load_external_data_for_model(raw_model, args.data_dir) + + # Run the checker to test whether the file is above the threshold for + # in-memory shape inference. If not, go ahead and do the shape inference. + try: + onnx.checker.check_model(raw_model) + inferred_model = onnx.shape_inference.infer_shapes(raw_model) + return inferred_model + except ValueError: + pass + + # The following code was an attempt to work around the bug where models + # with external data produce invalid output shapes after infer_shapes_path. + # It works with small models but threw an error for llama seeming to + # indicate that the protobuf is corrupt. + # + # temp_raw_file = temp_dir / "raw.onnx" + # onnx.save(raw_model, temp_raw_file, save_as_external_data=False) + # onnx.shape_inference.infer_shapes_path(temp_raw_file, temp_inferred_file) + # inferred_model = onnx.load(temp_inferred_file) + + # Model is too big for in-memory inference: do file-based shape inference + # to a temp file. + temp_inferred_file = temp_dir / "inferred.onnx" + onnx.shape_inference.infer_shapes_path(args.input_file, temp_inferred_file) + + # Sanity check the shape-inferred model to be sure we have a good model + # for the importer. This call uses the file-based method, as the + # in-memory method (passing the loaded model) fails due to the 2 GB limit. + # + # TODO: this call throws an exception because it can't find the external + # data files, and there doesn't appear to be a way to let the checker know + # where to find them. + # + # onnx.checker.check_model(temp_inferred_file) + + # Load the temp file and the external data. + inferred_model = onnx.load(temp_inferred_file, load_external_data=False) + data_dir = Path(input_dir if args.temp_dir is None else args.data_dir) + onnx.load_external_data_for_model(inferred_model, data_dir) + + # Remove the inferred shape file unless asked to keep it + if not args.keep_temps: + shutil.rmtree(temp_dir) + + return inferred_model + + +def parse_arguments(argv=None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Torch-mlir ONNX import tool") + parser.add_argument("input_file", help="ONNX protobuf input", type=Path) + parser.add_argument( + "-o", dest="output_file", help="Output path (or '-' for stdout)" + ) + parser.add_argument( + "--no-verify", + action="store_true", + help="Disable verification prior to printing", + ) + parser.add_argument( + "--keep-temps", action="store_true", help="Keep intermediate files" + ) + parser.add_argument( + "--temp-dir", + help="Pre-existing directory in which to create temporary files." + ' For example, to place temporaries under the directory "foo/bar",' + ' specify --temp-dir=foo/bar. "foo/bar" must already exist.' + " Defaults to the directory of the input file.", + type=Path, + ) + parser.add_argument( + "--data-dir", + help="Path between CWD and the base directory of the data," + " excluding the directories given in the 'location' argument of " + " convert_model_to_external_data. For example, if 'location' was" + ' "data/data.bin" and the relative path from CWD to that .bin file is' + ' a/b/data/data.bin, then set data-dir to "a/b".' + " Defaults to the directory of the input file.", + type=Path, + ) + args = parser.parse_args(argv) + return args + + +def _cli_main(): + sys.exit(main(parse_arguments())) + + +if __name__ == "__main__": + _cli_main() diff --git a/pytorch-hash.txt b/pytorch-hash.txt index 2caf78c61ce4..a5e23f46ea17 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -a111e45dfe64cd565b2c0369b683f67d6658d2cc +ce013333221ff2d1285a8e8cf7c427584e65fea2 diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index ca574a655eac..3ab13460e59a 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -4,4 +4,4 @@ # release page, and we use this page as an additional source for the wheels. -f https://xilinx.github.io/torch-mlir/package-index/ --pre -torch==2.3.0.dev20240108 +torch==2.3.0.dev20240307 diff --git a/setup.py b/setup.py index 9ed35b08da33..4863a9807522 100644 --- a/setup.py +++ b/setup.py @@ -30,74 +30,123 @@ # on the CMake side to organize that directory already, so we avoid duplicating # that here, and just package up its contents. import os +import pathlib import shutil import subprocess import sys -import sysconfig +import multiprocessing from distutils.command.build import build as _build -from distutils.sysconfig import get_python_inc from setuptools import setup, Extension from setuptools.command.build_ext import build_ext from setuptools.command.build_py import build_py +def check_env_flag(name: str, default=None) -> bool: + return str(os.getenv(name, default)).upper() in ["ON", "1", "YES", "TRUE", "Y"] + + PACKAGE_VERSION = os.environ.get("TORCH_MLIR_PYTHON_PACKAGE_VERSION") or "0.0.1" # If true, enable LTC build by default TORCH_MLIR_ENABLE_LTC_DEFAULT = True -TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS = int(os.environ.get('TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS', False)) -if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS: - import torch +TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS = check_env_flag( + 'TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS', False) +LLVM_INSTALL_DIR = os.getenv('LLVM_INSTALL_DIR', None) +SRC_DIR = pathlib.Path(__file__).parent.absolute() +CMAKE_BUILD_TYPE = os.getenv("CMAKE_BUILD_TYPE", "Release") + # Build phase discovery is unreliable. Just tell it what phases to run. class CustomBuild(_build): + def initialize_options(self): + _build.initialize_options(self) + # Make setuptools not steal the build directory name, + # because the mlir c++ developers are quite + # used to having build/ be for cmake + self.build_base = "setup_build" + def run(self): self.run_command("build_py") self.run_command("build_ext") self.run_command("build_scripts") + class CMakeBuild(build_py): + def cmake_build(self, cmake_build_dir): + llvm_dir = str(SRC_DIR / "externals" / "llvm-project" / "llvm") + enable_ltc = check_env_flag('TORCH_MLIR_ENABLE_LTC', TORCH_MLIR_ENABLE_LTC_DEFAULT) + max_jobs = os.getenv("MAX_JOBS") or str(multiprocessing.cpu_count()) + + cmake_config_args = [ + f"cmake", + f"-DCMAKE_BUILD_TYPE={CMAKE_BUILD_TYPE}", + f"-DPython3_EXECUTABLE={sys.executable}", + f"-DPython3_FIND_VIRTUALENV=ONLY", + f"-DMLIR_ENABLE_BINDINGS_PYTHON=ON", + f"-DLLVM_TARGETS_TO_BUILD=host", + f"-DLLVM_ENABLE_ZSTD=OFF", + # Optimization options for building wheels. + f"-DCMAKE_VISIBILITY_INLINES_HIDDEN=ON", + f"-DCMAKE_C_VISIBILITY_PRESET=hidden", + f"-DCMAKE_CXX_VISIBILITY_PRESET=hidden", + f"-DTORCH_MLIR_ENABLE_LTC={'ON' if enable_ltc else 'OFF'}", + f"-DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS={'OFF' if TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else 'ON'}", + ] + if LLVM_INSTALL_DIR: + cmake_config_args += [ + f"-DMLIR_DIR='{LLVM_INSTALL_DIR}/lib/cmake/mlir/'", + f"-DLLVM_DIR='{LLVM_INSTALL_DIR}/lib/cmake/llvm/'", + f"{SRC_DIR}", + ] + else: + cmake_config_args += [ + f"-DLLVM_ENABLE_PROJECTS=mlir", + f"-DLLVM_EXTERNAL_PROJECTS='torch-mlir'", + f"-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR={SRC_DIR}", + f"{llvm_dir}", + ] + cmake_build_args = [ + f"cmake", + f"--build", + f".", + f"--config", + f"{CMAKE_BUILD_TYPE}", + f"--target", + f"TorchMLIRPythonModules", + f"--", + f"-j{max_jobs}" + ] + try: + subprocess.check_call(cmake_config_args, cwd=cmake_build_dir) + subprocess.check_call(cmake_build_args, cwd=cmake_build_dir) + except subprocess.CalledProcessError as e: + print("cmake build failed with\n", e) + print("debug by follow cmake command:") + sys.exit(e.returncode) + finally: + print(f"cmake config: {' '.join(cmake_config_args)}") + print(f"cmake build: {' '.join(cmake_build_args)}") + print(f"cmake workspace: {cmake_build_dir}") + + def run(self): target_dir = self.build_lib cmake_build_dir = os.getenv("TORCH_MLIR_CMAKE_BUILD_DIR") - custom_python_package_path = os.getenv("TORCH_MLIR_PYTHON_PACKAGE_DIR",None) if not cmake_build_dir: cmake_build_dir = os.path.abspath( os.path.join(target_dir, "..", "cmake_build")) - if custom_python_package_path is not None and os.path.isdir(custom_python_package_path): - python_package_dir = custom_python_package_path + if LLVM_INSTALL_DIR: + python_package_dir = os.path.join(cmake_build_dir, + "python_packages", + "torch_mlir") else: python_package_dir = os.path.join(cmake_build_dir, - "tools", "torch-mlir", "python_packages", - "torch_mlir") + "tools", "torch-mlir", "python_packages", + "torch_mlir") if not os.getenv("TORCH_MLIR_CMAKE_BUILD_DIR_ALREADY_BUILT"): - src_dir = os.path.abspath(os.path.dirname(__file__)) - llvm_dir = os.path.join( - src_dir, "externals", "llvm-project", "llvm") - - enable_ltc = int(os.environ.get('TORCH_MLIR_ENABLE_LTC', TORCH_MLIR_ENABLE_LTC_DEFAULT)) - - cmake_args = [ - f"-DCMAKE_BUILD_TYPE=Release", - f"-DPython3_EXECUTABLE={sys.executable}", - f"-DPython3_FIND_VIRTUALENV=ONLY", - f"-DLLVM_TARGETS_TO_BUILD=host", - f"-DMLIR_ENABLE_BINDINGS_PYTHON=ON", - f"-DLLVM_ENABLE_PROJECTS=mlir", - f"-DLLVM_ENABLE_ZSTD=OFF", - f"-DLLVM_EXTERNAL_PROJECTS=torch-mlir", - f"-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR={src_dir}", - # Optimization options for building wheels. - f"-DCMAKE_VISIBILITY_INLINES_HIDDEN=ON", - f"-DCMAKE_C_VISIBILITY_PRESET=hidden", - f"-DCMAKE_CXX_VISIBILITY_PRESET=hidden", - f"-DTORCH_MLIR_ENABLE_LTC={'ON' if enable_ltc else 'OFF'}", - f"-DTORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS={'ON' if TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else 'OFF'}", - ] - os.makedirs(cmake_build_dir, exist_ok=True) cmake_cache_file = os.path.join(cmake_build_dir, "CMakeCache.txt") if os.path.exists(cmake_cache_file): @@ -115,14 +164,7 @@ def run(self): shutil.rmtree(mlir_libs_dir) else: print(f"Not removing _mlir_libs dir (does not exist): {mlir_libs_dir}") - - subprocess.check_call(["cmake", llvm_dir] + - cmake_args, cwd=cmake_build_dir) - subprocess.check_call(["cmake", - "--build", ".", - "--config", "Release", - "--target", "TorchMLIRPythonModules"], - cwd=cmake_build_dir) + self.cmake_build(cmake_build_dir) if os.path.exists(target_dir): shutil.rmtree(target_dir, ignore_errors=False, onerror=None) @@ -149,8 +191,31 @@ def build_extension(self, ext): long_description = fh.read() +# Requires and extension modules depend on whether building PyTorch +# extensions. +INSTALL_REQUIRES = [ + "numpy", + "packaging", +] +EXT_MODULES = [ + CMakeExtension("torch_mlir._mlir_libs._torchMlir"), +] +NAME = "torch-mlir-core" + +# If building PyTorch extensions, customize. +if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS: + import torch + NAME = "torch-mlir" + INSTALL_REQUIRES.extend([ + f"torch=={torch.__version__}".split("+", 1)[0], + ]) + EXT_MODULES.extend([ + CMakeExtension("torch_mlir._mlir_libs._jit_ir_importer"), + ]) + + setup( - name="torch-mlir" if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else "torch-mlir-core", + name=NAME, version=f"{PACKAGE_VERSION}", author="Sean Silva", author_email="silvasean@google.com", @@ -163,10 +228,17 @@ def build_extension(self, ext): "built_ext": NoopBuildExtension, "build_py": CMakeBuild, }, - ext_modules=[ - CMakeExtension("torch_mlir._mlir_libs._jit_ir_importer"), - ] if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else [CMakeExtension("torch_mlir._mlir_libs._torchMlir")], - install_requires=["numpy", "packaging"] + ( - [f"torch=={torch.__version__}".split("+", 1)[0], ] if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else []), + ext_modules=EXT_MODULES, + install_requires=INSTALL_REQUIRES, + extras_require={ + "onnx": [ + "onnx>=1.15.0", + ], + }, + entry_points={ + "console_scripts": [ + "torch-mlir-import-onnx = torch_mlir.tools.import_onnx:_cli_main", + ], + }, zip_safe=False, ) diff --git a/test-requirements.txt b/test-requirements.txt index 0046a02f0d5e..b21e8dfcd021 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,4 +1,5 @@ pillow dill multiprocess +onnx==1.15.0 mpmath==1.3.0 diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 397d72a4896b..8cd8bab0032f 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch | FileCheck %s +// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s // Generally, the test cases accumulated here come from running the importer // over all included backend tests that involve simple ops with no model // level constants. This is a pragmatic choice which lets us have a lot @@ -11,6 +11,8 @@ func.func @test_abs(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5 return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_add func.func @test_add(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -19,6 +21,8 @@ func.func @test_add(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3 return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_add_bcast func.func @test_add_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -27,6 +31,8 @@ func.func @test_add_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vten return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_add_uint8 func.func @test_add_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -35,6 +41,8 @@ func.func @test_add_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vten return %0 : !torch.vtensor<[3,4,5],ui8> } +// ----- + // CHECK-LABEL: @test_and_bcast3v1d func.func @test_and_bcast3v1d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[5],i1>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.logical_and %arg0, %arg1 : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[5],i1> -> !torch.vtensor<[3,4,5],i1> @@ -42,6 +50,8 @@ func.func @test_and_bcast3v1d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.v return %0 : !torch.vtensor<[3,4,5],i1> } +// ----- + // CHECK-LABEL: @test_argmax_default_axis_example func.func @test_argmax_default_axis_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[1,2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 0 @@ -51,6 +61,8 @@ func.func @test_argmax_default_axis_example(%arg0: !torch.vtensor<[2,2],f32>) -> return %0 : !torch.vtensor<[1,2],si64> } +// ----- + // CHECK-LABEL: @test_argmax_negative_axis_keepdims_example func.func @test_argmax_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,1],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 1 @@ -60,6 +72,8 @@ func.func @test_argmax_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2 return %0 : !torch.vtensor<[2,1],si64> } +// ----- + // CHECK-LABEL: @test_argmax_no_keepdims_example func.func @test_argmax_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 1 @@ -69,6 +83,8 @@ func.func @test_argmax_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> return %0 : !torch.vtensor<[2],si64> } +// ----- + // CHECK-LABEL: @test_argmin_default_axis_example func.func @test_argmin_default_axis_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[1,2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 0 @@ -78,6 +94,8 @@ func.func @test_argmin_default_axis_example(%arg0: !torch.vtensor<[2,2],f32>) -> return %0 : !torch.vtensor<[1,2],si64> } +// ----- + // CHECK-LABEL: @test_argmin_negative_axis_keepdims_example func.func @test_argmin_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,1],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 1 @@ -87,6 +105,8 @@ func.func @test_argmin_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2 return %0 : !torch.vtensor<[2,1],si64> } +// ----- + // CHECK-LABEL: @test_argmin_no_keepdims_example func.func @test_argmin_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 1 @@ -96,6 +116,8 @@ func.func @test_argmin_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> return %0 : !torch.vtensor<[2],si64> } +// ----- + // CHECK-LABEL: @test_atan func.func @test_atan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.atan %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -103,6 +125,17 @@ func.func @test_atan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + +// CHECK-LABEL: @test_atanh +func.func @test_atanh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.atanh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Atanh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + // CHECK-LABEL: @test_acos func.func @test_acos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.acos %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -110,6 +143,31 @@ func.func @test_acos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + +// CHECK-LABEL: @test_bernoulli +func.func @test_bernoulli(%arg0: !torch.vtensor<[10],f64>) -> !torch.vtensor<[10],f64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %0 = torch.aten.bernoulli %arg0, %[[NONE]] : !torch.vtensor<[10],f64>, !torch.none -> !torch.vtensor<[10],f64> + %0 = torch.operator "onnx.Bernoulli"(%arg0) : (!torch.vtensor<[10],f64>) -> !torch.vtensor<[10],f64> + return %0 : !torch.vtensor<[10],f64> +} + +// ----- + +// CHECK-LABEL: @test_bernoulli_double +func.func @test_bernoulli_double(%arg0: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[BERNOULLI:.*]] = torch.aten.bernoulli %arg0, %[[NONE]] : !torch.vtensor<[10],f32>, !torch.none -> !torch.vtensor<[10],f32> + // CHECK: %[[DTYPE:.*]] = torch.constant.int 7 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %[[BERNOULLI]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f64> + %0 = torch.operator "onnx.Bernoulli"(%arg0) {torch.onnx.dtype = 11 : si64} : (!torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f64> + return %0 : !torch.vtensor<[10],f64> +} + +// ----- + // CHECK-LABEL: @test_bitshift_left_uint8 func.func @test_bitshift_left_uint8(%arg0: !torch.vtensor<[3],ui8>, %arg1: !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8> -> !torch.vtensor<[3],ui8> @@ -117,6 +175,8 @@ func.func @test_bitshift_left_uint8(%arg0: !torch.vtensor<[3],ui8>, %arg1: !torc return %0 : !torch.vtensor<[3],ui8> } +// ----- + // CHECK-LABEL: @test_bitshift_left_uint16 func.func @test_bitshift_left_uint16(%arg0: !torch.vtensor<[3],ui16>, %arg1: !torch.vtensor<[3],ui16>) -> !torch.vtensor<[3],ui16> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui16>, !torch.vtensor<[3],ui16> -> !torch.vtensor<[3],ui16> @@ -124,6 +184,8 @@ func.func @test_bitshift_left_uint16(%arg0: !torch.vtensor<[3],ui16>, %arg1: !to return %0 : !torch.vtensor<[3],ui16> } +// ----- + // CHECK-LABEL: @test_bitshift_left_uint32 func.func @test_bitshift_left_uint32(%arg0: !torch.vtensor<[3],ui32>, %arg1: !torch.vtensor<[3],ui32>) -> !torch.vtensor<[3],ui32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui32>, !torch.vtensor<[3],ui32> -> !torch.vtensor<[3],ui32> @@ -131,6 +193,8 @@ func.func @test_bitshift_left_uint32(%arg0: !torch.vtensor<[3],ui32>, %arg1: !to return %0 : !torch.vtensor<[3],ui32> } +// ----- + // CHECK-LABEL: @test_bitshift_left_uint64 func.func @test_bitshift_left_uint64(%arg0: !torch.vtensor<[3],ui64>, %arg1: !torch.vtensor<[3],ui64>) -> !torch.vtensor<[3],ui64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui64>, !torch.vtensor<[3],ui64> -> !torch.vtensor<[3],ui64> @@ -138,6 +202,8 @@ func.func @test_bitshift_left_uint64(%arg0: !torch.vtensor<[3],ui64>, %arg1: !to return %0 : !torch.vtensor<[3],ui64> } +// ----- + // CHECK-LABEL: @test_bitshift_right_uint8 func.func @test_bitshift_right_uint8(%arg0: !torch.vtensor<[3],ui8>, %arg1: !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8> -> !torch.vtensor<[3],ui8> @@ -145,6 +211,8 @@ func.func @test_bitshift_right_uint8(%arg0: !torch.vtensor<[3],ui8>, %arg1: !tor return %0 : !torch.vtensor<[3],ui8> } +// ----- + // CHECK-LABEL: @test_bitshift_right_uint16 func.func @test_bitshift_right_uint16(%arg0: !torch.vtensor<[3],ui16>, %arg1: !torch.vtensor<[3],ui16>) -> !torch.vtensor<[3],ui16> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui16>, !torch.vtensor<[3],ui16> -> !torch.vtensor<[3],ui16> @@ -152,6 +220,8 @@ func.func @test_bitshift_right_uint16(%arg0: !torch.vtensor<[3],ui16>, %arg1: !t return %0 : !torch.vtensor<[3],ui16> } +// ----- + // CHECK-LABEL: @test_bitshift_right_uint32 func.func @test_bitshift_right_uint32(%arg0: !torch.vtensor<[3],ui32>, %arg1: !torch.vtensor<[3],ui32>) -> !torch.vtensor<[3],ui32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui32>, !torch.vtensor<[3],ui32> -> !torch.vtensor<[3],ui32> @@ -159,6 +229,8 @@ func.func @test_bitshift_right_uint32(%arg0: !torch.vtensor<[3],ui32>, %arg1: !t return %0 : !torch.vtensor<[3],ui32> } +// ----- + // CHECK-LABEL: @test_bitshift_right_uint64 func.func @test_bitshift_right_uint64(%arg0: !torch.vtensor<[3],ui64>, %arg1: !torch.vtensor<[3],ui64>) -> !torch.vtensor<[3],ui64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui64>, !torch.vtensor<[3],ui64> -> !torch.vtensor<[3],ui64> @@ -166,6 +238,8 @@ func.func @test_bitshift_right_uint64(%arg0: !torch.vtensor<[3],ui64>, %arg1: !t return %0 : !torch.vtensor<[3],ui64> } +// ----- + // CHECK-LABEL: @test_bitwise_and_i16_3d func.func @test_bitwise_and_i16_3d(%arg0: !torch.vtensor<[3,4,5],si16>, %arg1: !torch.vtensor<[3,4,5],si16>) -> !torch.vtensor<[3,4,5],si16> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si16>, !torch.vtensor<[3,4,5],si16> -> !torch.vtensor<[3,4,5],si16> @@ -173,6 +247,8 @@ func.func @test_bitwise_and_i16_3d(%arg0: !torch.vtensor<[3,4,5],si16>, %arg1: ! return %0 : !torch.vtensor<[3,4,5],si16> } +// ----- + // CHECK-LABEL: @test_bitwise_and_i32_2d func.func @test_bitwise_and_i32_2d(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si32>, !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],si32> @@ -180,6 +256,8 @@ func.func @test_bitwise_and_i32_2d(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !to return %0 : !torch.vtensor<[3,4],si32> } +// ----- + // CHECK-LABEL: @test_bitwise_and_ui8_bcast_4v3d func.func @test_bitwise_and_ui8_bcast_4v3d(%arg0: !torch.vtensor<[3,4,5,6],ui8>, %arg1: !torch.vtensor<[4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5,6],ui8>, !torch.vtensor<[4,5,6],ui8> -> !torch.vtensor<[3,4,5,6],ui8> @@ -187,6 +265,8 @@ func.func @test_bitwise_and_ui8_bcast_4v3d(%arg0: !torch.vtensor<[3,4,5,6],ui8>, return %0 : !torch.vtensor<[3,4,5,6],ui8> } +// ----- + // CHECK-LABEL: @test_bitwise_or_i16_4d func.func @test_bitwise_or_i16_4d(%arg0: !torch.vtensor<[3,4,5,6],si8>, %arg1: !torch.vtensor<[3,4,5,6],si8>) -> !torch.vtensor<[3,4,5,6],si8> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_or.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5,6],si8>, !torch.vtensor<[3,4,5,6],si8> -> !torch.vtensor<[3,4,5,6],si8> @@ -194,6 +274,8 @@ func.func @test_bitwise_or_i16_4d(%arg0: !torch.vtensor<[3,4,5,6],si8>, %arg1: ! return %0 : !torch.vtensor<[3,4,5,6],si8> } +// ----- + // CHECK-LABEL: @test_bitwise_or_i32_2d func.func @test_bitwise_or_i32_2d(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_or.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si32>, !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],si32> @@ -201,6 +283,8 @@ func.func @test_bitwise_or_i32_2d(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !tor return %0 : !torch.vtensor<[3,4],si32> } +// ----- + // CHECK-LABEL: @test_bitwise_or_ui8_bcast_4v3d func.func @test_bitwise_or_ui8_bcast_4v3d(%arg0: !torch.vtensor<[3,4,5,6],ui8>, %arg1: !torch.vtensor<[4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_or.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5,6],ui8>, !torch.vtensor<[4,5,6],ui8> -> !torch.vtensor<[3,4,5,6],ui8> @@ -208,6 +292,8 @@ func.func @test_bitwise_or_ui8_bcast_4v3d(%arg0: !torch.vtensor<[3,4,5,6],ui8>, return %0 : !torch.vtensor<[3,4,5,6],ui8> } +// ----- + // CHECK-LABEL: @test_bitwise_not_2d func.func @test_bitwise_not_2d(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],si32> @@ -215,6 +301,8 @@ func.func @test_bitwise_not_2d(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vten return %0 : !torch.vtensor<[3,4],si32> } +// ----- + // CHECK-LABEL: @test_bitwise_not_4d func.func @test_bitwise_not_4d(%arg0: !torch.vtensor<[3,4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4,5,6],ui8> -> !torch.vtensor<[3,4,5,6],ui8> @@ -222,6 +310,8 @@ func.func @test_bitwise_not_4d(%arg0: !torch.vtensor<[3,4,5,6],ui8>) -> !torch.v return %0 : !torch.vtensor<[3,4,5,6],ui8> } +// ----- + // CHECK-LABEL: @test_bitwise_xor_i16_3d func.func @test_bitwise_xor_i16_3d(%arg0: !torch.vtensor<[3,4,5],si16>, %arg1: !torch.vtensor<[3,4,5],si16>) -> !torch.vtensor<[3,4,5],si16> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_xor.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si16>, !torch.vtensor<[3,4,5],si16> -> !torch.vtensor<[3,4,5],si16> @@ -229,6 +319,8 @@ func.func @test_bitwise_xor_i16_3d(%arg0: !torch.vtensor<[3,4,5],si16>, %arg1: ! return %0 : !torch.vtensor<[3,4,5],si16> } +// ----- + // CHECK-LABEL: @test_bitwise_xor_i32_2d func.func @test_bitwise_xor_i32_2d(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_xor.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si32>, !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],si32> @@ -236,6 +328,8 @@ func.func @test_bitwise_xor_i32_2d(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !to return %0 : !torch.vtensor<[3,4],si32> } +// ----- + // CHECK-LABEL: @test_bitwise_xor_ui8_bcast_4v3d func.func @test_bitwise_xor_ui8_bcast_4v3d(%arg0: !torch.vtensor<[3,4,5,6],ui8>, %arg1: !torch.vtensor<[4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_xor.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5,6],ui8>, !torch.vtensor<[4,5,6],ui8> -> !torch.vtensor<[3,4,5,6],ui8> @@ -243,6 +337,8 @@ func.func @test_bitwise_xor_ui8_bcast_4v3d(%arg0: !torch.vtensor<[3,4,5,6],ui8>, return %0 : !torch.vtensor<[3,4,5,6],ui8> } +// ----- + // CHECK-LABEL: @test_cast_BFLOAT16_to_FLOAT func.func @test_cast_BFLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],bf16>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 6 @@ -253,6 +349,8 @@ func.func @test_cast_BFLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],bf16>) -> !to return %0 : !torch.vtensor<[3,4],f32> } +// ----- + // CHECK-LABEL: @test_cast_DOUBLE_to_FLOAT func.func @test_cast_DOUBLE_to_FLOAT(%arg0: !torch.vtensor<[3,4],f64>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 6 @@ -263,6 +361,8 @@ func.func @test_cast_DOUBLE_to_FLOAT(%arg0: !torch.vtensor<[3,4],f64>) -> !torch return %0 : !torch.vtensor<[3,4],f32> } +// ----- + // CHECK-LABEL: @test_cast_DOUBLE_to_FLOAT16 func.func @test_cast_DOUBLE_to_FLOAT16(%arg0: !torch.vtensor<[3,4],f64>) -> !torch.vtensor<[3,4],f16> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 5 @@ -273,6 +373,8 @@ func.func @test_cast_DOUBLE_to_FLOAT16(%arg0: !torch.vtensor<[3,4],f64>) -> !tor return %0 : !torch.vtensor<[3,4],f16> } +// ----- + // CHECK-LABEL: @test_cast_FLOAT_to_BFLOAT16 func.func @test_cast_FLOAT_to_BFLOAT16(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],bf16> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 15 @@ -283,6 +385,8 @@ func.func @test_cast_FLOAT_to_BFLOAT16(%arg0: !torch.vtensor<[3,4],f32>) -> !tor return %0 : !torch.vtensor<[3,4],bf16> } +// ----- + // CHECK-LABEL: @test_cast_FLOAT_to_DOUBLE func.func @test_cast_FLOAT_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 7 @@ -293,6 +397,8 @@ func.func @test_cast_FLOAT_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f32>) -> !torch return %0 : !torch.vtensor<[3,4],f64> } +// ----- + // CHECK-LABEL: @test_cast_FLOAT_to_FLOAT16 func.func @test_cast_FLOAT_to_FLOAT16(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f16> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 5 @@ -303,6 +409,8 @@ func.func @test_cast_FLOAT_to_FLOAT16(%arg0: !torch.vtensor<[3,4],f32>) -> !torc return %0 : !torch.vtensor<[3,4],f16> } +// ----- + // CHECK-LABEL: @test_cast_FLOAT16_to_DOUBLE func.func @test_cast_FLOAT16_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f16>) -> !torch.vtensor<[3,4],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 7 @@ -313,6 +421,20 @@ func.func @test_cast_FLOAT16_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f16>) -> !tor return %0 : !torch.vtensor<[3,4],f64> } +// ----- + +// CHECK-LABEL: @test_cast_FLOAT_to_BOOL +func.func @test_cast_FLOAT_to_BOOL(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 11 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],i1> + %0 = torch.operator "onnx.Cast"(%arg0) {torch.onnx.to = 9 : si64} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],i1> + return %0 : !torch.vtensor<[3,4],i1> +} + +// ----- + // CHECK-LABEL: @test_cast_FLOAT16_to_FLOAT func.func @test_cast_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 6 @@ -323,6 +445,56 @@ func.func @test_cast_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>) -> !torc return %0 : !torch.vtensor<[3,4],f32> } +// ----- + +// CHECK-LABEL: @test_castlike_BFLOAT16_to_FLOAT +func.func @test_castlike_BFLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],bf16>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 6 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.CastLike"(%arg0, %arg1) : (!torch.vtensor<[3,4],bf16>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: @test_castlike_DOUBLE_to_FLOAT +func.func @test_castlike_DOUBLE_to_FLOAT(%arg0: !torch.vtensor<[3,4],f64>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 6 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.CastLike"(%arg0, %arg1) : (!torch.vtensor<[3,4],f64>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: @test_castlike_FLOAT_to_DOUBLE +func.func @test_castlike_FLOAT_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],f64>) -> !torch.vtensor<[3,4],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 7 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f64> + %0 = torch.operator "onnx.CastLike"(%arg0, %arg1) : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],f64>) -> !torch.vtensor<[3,4],f64> + return %0 : !torch.vtensor<[3,4],f64> +} + +// ----- + +// CHECK-LABEL: @test_castlike_FLOAT16_to_FLOAT +func.func @test_castlike_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 6 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.CastLike"(%arg0, %arg1) : (!torch.vtensor<[3,4],f16>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + // CHECK-LABEL: @test_ceil_example func.func @test_ceil_example(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.ceil %arg0 : !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32> @@ -330,6 +502,8 @@ func.func @test_ceil_example(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[ return %0 : !torch.vtensor<[2],f32> } +// ----- + // CHECK-LABEL: @test_ceil func.func @test_ceil(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.ceil %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -337,6 +511,8 @@ func.func @test_ceil(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_clip_default_int8_min func.func @test_clip_default_int8_min(%arg0: !torch.vtensor<[3,4,5],si8>, %arg1: !torch.vtensor<[],si8>) -> !torch.vtensor<[3,4,5],si8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.clamp_min.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si8>, !torch.vtensor<[],si8> -> !torch.vtensor<[3,4,5],si8> @@ -344,6 +520,18 @@ func.func @test_clip_default_int8_min(%arg0: !torch.vtensor<[3,4,5],si8>, %arg1: return %0 : !torch.vtensor<[3,4,5],si8> } +// ----- + +// CHECK-LABEL: @test_clip_default_int8_max +func.func @test_clip_default_int8_max(%arg0: !torch.vtensor<[3,4,5],si8>, %arg1: !torch.vtensor<[],si8>) -> !torch.vtensor<[3,4,5],si8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + // CHECK: torch.aten.clamp.Tensor %arg0, %none, %arg1 : !torch.vtensor<[3,4,5],si8>, !torch.none, !torch.vtensor<[],si8> -> !torch.vtensor<[3,4,5],si8> + %0 = torch.operator "onnx.Clip"(%arg0, %none, %arg1) : (!torch.vtensor<[3,4,5],si8>, !torch.none, !torch.vtensor<[],si8>) -> !torch.vtensor<[3,4,5],si8> + return %0 : !torch.vtensor<[3,4,5],si8> +} + +// ----- + // CHECK-LABEL: @test_clip_default_min func.func @test_clip_default_min(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.clamp_min.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[3,4,5],f32> @@ -351,6 +539,8 @@ func.func @test_clip_default_min(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_clip_example func.func @test_clip_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.clamp.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[3],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[3],f32> @@ -358,6 +548,8 @@ func.func @test_clip_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtens return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: @test_clip func.func @test_clip(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.clamp.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[3,4,5],f32> @@ -365,6 +557,22 @@ func.func @test_clip(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[ return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + +module { + func.func @test_clip_attrs(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64} { + %none = torch.constant.none + + // CHECK: %[[MIN:.+]] = torch.vtensor.literal(dense<-5.000000e-01> : tensor<3x4xf32>) : !torch.vtensor<[3,4],f32> + // CHECK: %[[MAX:.+]] = torch.vtensor.literal(dense<5.000000e-01> : tensor<3x4xf32>) : !torch.vtensor<[3,4],f32> + // CHECK: %[[CLAMP:.+]] = torch.aten.clamp.Tensor %arg0, %[[MIN]], %[[MAX]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Clip"(%arg0) {torch.onnx.max = 5.000000e-01 : f32, torch.onnx.min = -5.000000e-01 : f32} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> + } +} + +// ----- + // CHECK-LABEL: @test_cos_example func.func @test_cos_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.cos %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> @@ -372,6 +580,8 @@ func.func @test_cos_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3 return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: @test_cos func.func @test_cos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.cos %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -379,6 +589,120 @@ func.func @test_cos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5 return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + +// CHECK-LABEL: @test_cosh_example +func.func @test_cosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.cosh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Cosh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: @test_cosh +func.func @test_cosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.cosh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Cosh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_acosh_example +func.func @test_acosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.acosh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Acosh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: @test_acosh +func.func @test_acosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.acosh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Acosh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_asin_example +func.func @test_asin_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.asin %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Asin"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: @test_asin +func.func @test_asin(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.asin %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Asin"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_asinh_example +func.func @test_asinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.asinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Asinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: @test_asinh +func.func @test_asinh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.asinh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Asinh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_dequantizelinear_si8 +func.func @test_dequantizelinear_si8(%arg0: !torch.vtensor<[6],si8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],f32> + // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si8> -> !torch.int + // CHECK: %[[MAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[ZP]] + // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[MAKE]] + // CHECK: return %[[DEQ]] + return %0 : !torch.vtensor<[6],f32> +} + +// ----- + +// CHECK-LABEL: @test_dequantizelinear_ui8 +func.func @test_dequantizelinear_ui8(%arg0: !torch.vtensor<[6],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],f32> + // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[MAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[ZP]] + // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[MAKE]] + // CHECK: return %[[DEQ]] + return %0 : !torch.vtensor<[6],f32> +} + +// ----- + +// CHECK-LABEL: @test_dequantizelinear_i32 +func.func @test_dequantizelinear_i32(%arg0: !torch.vtensor<[6],si32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si32>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],si32>, !torch.vtensor<[],f32>, !torch.vtensor<[],si32>) -> !torch.vtensor<[6],f32> + // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si32> -> !torch.int + // CHECK: %[[MAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[ZP]] + // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[MAKE]] + // CHECK: return %[[DEQ]] + return %0 : !torch.vtensor<[6],f32> +} + +// ----- + + // CHECK-LABEL: @test_div_bcast func.func @test_div_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -386,6 +710,8 @@ func.func @test_div_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vten return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_div_example func.func @test_div_example(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32> @@ -393,6 +719,8 @@ func.func @test_div_example(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtenso return %0 : !torch.vtensor<[2],f32> } +// ----- + // CHECK-LABEL: @test_div func.func @test_div(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -400,6 +728,17 @@ func.func @test_div(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3 return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + +// CHECK-LABEL: @test_div_int32 +func.func @test_div_int32(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32> -> !torch.vtensor<[3,4,5],si32> + %0 = torch.operator "onnx.Div"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32> + return %0 : !torch.vtensor<[3,4,5],si32> +} + +// ----- + // CHECK-LABEL: @test_div_uint8 func.func @test_div_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8> -> !torch.vtensor<[3,4,5],ui8> @@ -407,6 +746,8 @@ func.func @test_div_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vten return %0 : !torch.vtensor<[3,4,5],ui8> } +// ----- + // CHECK-LABEL: @test_equal_bcast func.func @test_equal_bcast(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[5],si32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[5],si32> -> !torch.vtensor<[3,4,5],i1> @@ -414,6 +755,17 @@ func.func @test_equal_bcast(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.v return %0 : !torch.vtensor<[3,4,5],i1> } +// ----- + +// CHECK-LABEL: @test_erf +func.func @test_erf(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.erf %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Erf"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + // CHECK-LABEL: @test_equal func.func @test_equal(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32> -> !torch.vtensor<[3,4,5],i1> @@ -421,6 +773,8 @@ func.func @test_equal(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor return %0 : !torch.vtensor<[3,4,5],i1> } +// ----- + // CHECK-LABEL: @test_floor_example func.func @test_floor_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.floor %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> @@ -428,9 +782,886 @@ func.func @test_floor_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor< return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: @test_floor func.func @test_floor(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.floor %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Floor"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } + +// ----- + +// CHECK-LABEL: @test_averagepool_1d_default +func.func @test_averagepool_1d_default(%arg0: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.avg_pool1d %arg0, %0, %2, %1, %false, %true : !torch.vtensor<[1,3,32],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool -> !torch.vtensor<[1,3,31],f32> + %0 = torch.operator "onnx.AveragePool"(%arg0) {torch.onnx.kernel_shape = [2 : si64], torch.onnx.count_include_pad = 1 : si64} : (!torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32> + return %0 : !torch.vtensor<[1,3,31],f32> +} + +// ----- + +// CHECK-LABEL: @test_averagepool_2d_ceil +func.func @test_averagepool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.avg_pool2d %arg0, %0, %2, %1, %true, %false, %none : !torch.vtensor<[1,1,4,4],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,2,2],f32> + %0 = torch.operator "onnx.AveragePool"(%arg0) {torch.onnx.ceil_mode = 1 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32> + return %0 : !torch.vtensor<[1,1,2,2],f32> +} + +// ----- + +// CHECK-LABEL: @test_averagepool_3d_default +func.func @test_averagepool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.avg_pool3d %arg0, %0, %2, %1, %false, %false{{.*}}, %none : !torch.vtensor<[1,3,32,32,32],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,31,31,31],f32> + %0 = torch.operator "onnx.AveragePool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32> + return %0 : !torch.vtensor<[1,3,31,31,31],f32> +} + +// ----- + +// CHECK-LABEL: @test_averagepool_with_padding +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,20,64,48],f32> +// CHECK: torch.aten.avg_pool2d %[[ARG]], {{.*}} : !torch.vtensor<[1,20,64,48],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,20,32,24],f32> + +func.func @test_averagepool_with_padding(%arg0: !torch.vtensor<[1,20,64,48],f32>) -> !torch.vtensor<[1,20,32,24],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 19 : si64} { + + %0 = torch.operator "onnx.AveragePool"(%arg0) {torch.onnx.ceil_mode = 0 : si64, torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,20,64,48],f32>) -> !torch.vtensor<[1,20,32,24],f32> + return %0 : !torch.vtensor<[1,20,32,24],f32> +} + +// ----- + +// CHECK-LABEL: @test_conv_with_strides_no_padding +func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,3,2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_1]], %[[C0_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,3,2],f32> + %0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,3,2],f32> + return %0 : !torch.vtensor<[1,1,3,2],f32> +} + +// ----- + +// CHECK-LABEL: @test_conv_with_strides_padding +func.func @test_conv_with_strides_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,4,3],f32> + %0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32> + return %0 : !torch.vtensor<[1,1,4,3],f32> +} + +// ----- + +// CHECK-LABEL: @test_conv_with_bias_strides_padding +func.func @test_conv_with_bias_strides_padding(%arg0: !torch.vtensor<[?,?,224,224],f32>, %arg1: !torch.vtensor<[64,3,7,7],f32>, %arg2: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,112,112],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[C3_0:.*]] = torch.constant.int 3 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C3]], %[[C3_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %arg2, %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[?,?,224,224],f32>, !torch.vtensor<[64,3,7,7],f32>, !torch.vtensor<[64],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[?,64,112,112],f32> + %0 = torch.operator "onnx.Conv"(%arg0, %arg1, %arg2) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [7 : si64, 7 : si64], torch.onnx.pads = [3 : si64, 3 : si64, 3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[?,?,224,224],f32>, !torch.vtensor<[64,3,7,7],f32>, !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,112,112],f32> + return %0 : !torch.vtensor<[?,64,112,112],f32> +} + +// ----- + +// CHECK-LABEL: @test_convtranspose_dilations +func.func @test_convtranspose_dilations(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,1,2,2],f32>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[C0_2:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_1]], %[[C0_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,1,2,2],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,5,5],f32> + %0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.dilations = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,1,2,2],f32>) -> !torch.vtensor<[1,1,5,5],f32> + return %0 : !torch.vtensor<[1,1,5,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_convtranspose +func.func @test_convtranspose(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,5,5],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[C0_2:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_1]], %[[C0_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,2,5,5],f32> + %0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,5,5],f32> + return %0 : !torch.vtensor<[1,2,5,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_convtranspose_pad + func.func @test_convtranspose_pad(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C3]], %[[C2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,2,10,8],f32> + %0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.output_padding = [1 : si64, 1 : si64], torch.onnx.strides = [3 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32> + return %0 : !torch.vtensor<[1,2,10,8],f32> + } + +// ----- + +// CHECK-LABEL: @test_convtranspose_pads + func.func @test_convtranspose_pads(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,7,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_0]], %[[C1_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C3]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,2,7,3],f32> + %0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.pads = [1 : si64, 2 : si64, 1 : si64, 2 : si64], torch.onnx.strides = [3 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,7,3],f32> + return %0 : !torch.vtensor<[1,2,7,3],f32> + } + +// ----- + +// CHECK-LABEL: @test_batchnorm_epsilon +func.func @test_batchnorm_epsilon(%arg0: !torch.vtensor<[2,3,4,5],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>, %arg4: !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[MOMENTUM:.*]] = torch.constant.float 0.89999997615814208 + // CHECK: %[[EPS:.*]] = torch.constant.float 0.0099999997764825821 + // CHECK: torch.aten.batch_norm %arg0, %arg1, %arg2, %arg3, %arg4, %[[FALSE]], %[[MOMENTUM]], %[[EPS]], %[[FALSE]] : !torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[2,3,4,5],f32> + %0 = torch.operator "onnx.BatchNormalization"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.epsilon = 0.00999999977 : f32} : (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32> + return %0 : !torch.vtensor<[2,3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_batchnorm_example +func.func @test_batchnorm_example(%arg0: !torch.vtensor<[2,3,4,5],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>, %arg4: !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[MOMENTUM:.*]] = torch.constant.float 0.89999997615814208 + // CHECK: %[[EPS:.*]] = torch.constant.float 9.9999997473787516E-6 + // CHECK: torch.aten.batch_norm %arg0, %arg1, %arg2, %arg3, %arg4, %[[FALSE]], %[[MOMENTUM]], %[[EPS]], %[[FALSE]] : !torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[2,3,4,5],f32> + %0 = torch.operator "onnx.BatchNormalization"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32> + return %0 : !torch.vtensor<[2,3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_concat_1d_axis_0 +func.func @test_concat_1d_axis_0(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int 0 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[4],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @test_concat_1d_axis_negative_1 +func.func @test_concat_1d_axis_negative_1(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int -1 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[4],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @test_concat_2d_axis_0 +func.func @test_concat_2d_axis_0(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[4,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int 0 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[4,2],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[4,2],f32> + return %0 : !torch.vtensor<[4,2],f32> +} + +// ----- + +// CHECK-LABEL: @test_concat_2d_axis_1 +func.func @test_concat_2d_axis_1(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int 1 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[2,4],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,4],f32> + return %0 : !torch.vtensor<[2,4],f32> +} + +// ----- + +// CHECK-LABEL: @test_concat_2d_axis_negative_1 +func.func @test_concat_2d_axis_negative_1(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int -1 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[2,4],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,4],f32> + return %0 : !torch.vtensor<[2,4],f32> +} + +// ----- + +// CHECK-LABEL: @test_concat_2d_axis_negative_2 +func.func @test_concat_2d_axis_negative_2(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[4,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int -2 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[4,2],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[4,2],f32> + return %0 : !torch.vtensor<[4,2],f32> +} + +// ----- + +// CHECK-LABEL: @test_concat_3d_axis_0 +func.func @test_concat_3d_axis_0(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[4,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int 0 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[4,2,2],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[4,2,2],f32> + return %0 : !torch.vtensor<[4,2,2],f32> +} + +// ----- + +// CHECK-LABEL: @test_concat_3d_axis_1 +func.func @test_concat_3d_axis_1(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,4,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int 1 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[2,4,2],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,4,2],f32> + return %0 : !torch.vtensor<[2,4,2],f32> +} + +// ----- + +// CHECK-LABEL: @test_concat_3d_axis_2 +func.func @test_concat_3d_axis_2(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int 2 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[2,2,4],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,4],f32> + return %0 : !torch.vtensor<[2,2,4],f32> +} + +// ----- + +// CHECK-LABEL: @test_concat_3d_axis_negative_1 +func.func @test_concat_3d_axis_negative_1(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int -1 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[2,2,4],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,4],f32> + return %0 : !torch.vtensor<[2,2,4],f32> +} + +// ----- + +// CHECK-LABEL: @test_concat_3d_axis_negative_2 +func.func @test_concat_3d_axis_negative_2(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,4,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int -2 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[2,4,2],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,4,2],f32> + return %0 : !torch.vtensor<[2,4,2],f32> +} + +// ----- + +// CHECK-LABEL: @test_concat_3d_axis_negative_3 +func.func @test_concat_3d_axis_negative_3(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[4,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int -3 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[4,2,2],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -3 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[4,2,2],f32> + return %0 : !torch.vtensor<[4,2,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_exp +func.func @test_exp(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64} { + // CHECK: torch.aten.exp %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Exp"(%arg0) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: @test_expand_dim2_shape2 +func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !torch.vtensor<[2],si32>) + -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[SEL0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32> + // CHECK-DAG: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] : !torch.vtensor<[],si32> -> !torch.int + // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[SZ0:.+]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int + // CHECK-DAG: %[[MX0:.+]] = torch.prim.max.int %[[ITEM0]], %[[SZ0]] : !torch.int, !torch.int -> !torch.int + // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32> + // CHECK-DAG: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] : !torch.vtensor<[],si32> -> !torch.int + // CHECK-DAG: %[[I1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int + // CHECK-DAG: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[MX0]], %[[MX1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.broadcast_to %arg0, %[[LIST]] : !torch.vtensor<[1,4],f32>, !torch.list -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: @test_expand_dim2_shape3 +func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[I0:.+]] = torch.constant.int 0 + // CHECK-NEXT: %[[I0_0:.+]] = torch.constant.int 0 + // CHECK-NEXT: %[[SEL0:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I0_0]] + // CHECK-NEXT: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] + // CHECK-NEXT: %[[I1:.+]] = torch.constant.int 1 + // CHECK-NEXT: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I1]] + // CHECK-NEXT: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] + // CHECK-NEXT: %[[D1:.+]] = torch.constant.int 0 + // CHECK-NEXT: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[D1]] + // CHECK-NEXT: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int + // CHECK-NEXT: %[[I2:.+]] = torch.constant.int 2 + // CHECK-NEXT: %[[SEL2:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I2]] + // CHECK-NEXT: %[[ITEM2:.+]] = torch.aten.item %[[SEL2]] + // CHECK-NEXT: %[[D2:.+]] = torch.constant.int 1 + // CHECK-NEXT: %[[SZ2:.+]] = torch.aten.size.int %arg0, %[[D2]] + // CHECK-NEXT: %[[MX2:.+]] = torch.prim.max.int %[[ITEM2]], %[[SZ2]] + // CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[MX1]], %[[MX2]] + // CHECK-NEXT: %[[EXPAND:.+]] = torch.aten.broadcast_to %arg0, %[[LIST]] + // CHECK: return %[[EXPAND]] + %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32> + return %0 : !torch.vtensor<[2,3,6],f32> +} + +// ----- + +// CHECK-LABEL: @test_dropout +func.func @test_dropout(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.dropout %arg0, %float5.000000e-01, %false : !torch.vtensor<[3],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3],f32 + %0 = torch.operator "onnx.Dropout"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: @test_dropout_default +func.func @test_dropout_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.dropout %arg0, %float5.000000e-01, %false : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Dropout"(%arg0) {torch.onnx.seed = 0 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_dropout_default_mask +func.func @test_dropout_default_mask(%arg0: !torch.vtensor<[3,4,5],f32>) -> (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1>) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.dropout %arg0, %float5.000000e-01, %false : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.ones_like %arg0, %int11, %none, %none, %none, %none : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4,5],i1> + %0:2 = torch.operator "onnx.Dropout"(%arg0) {torch.onnx.seed = 0 : si64} : (!torch.vtensor<[3,4,5],f32>) -> (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1>) + return %0#0, %0#1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1> +} + +// ----- + +// CHECK-LABEL: @test_dropout_default_mask_ratio +func.func @test_dropout_default_mask_ratio(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[],f32>) -> (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1>) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.dropout %arg0, %0, %false : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.ones_like %arg0, %int11, %none, %none, %none, %none : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4,5],i1> + %0:2 = torch.operator "onnx.Dropout"(%arg0, %arg1) {torch.onnx.seed = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[],f32>) -> (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1>) + return %0#0, %0#1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1> +} + +// ----- + +// CHECK-LABEL: @test_dropout_default_ratio +func.func @test_dropout_default_ratio(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.dropout %arg0, %0, %false : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Dropout"(%arg0, %arg1) {torch.onnx.seed = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_training_dropout_zero_ratio +func.func @test_training_dropout_zero_ratio(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],i1>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.dropout %arg0, %0, %2 : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Dropout"(%arg0, %arg1, %arg2) {torch.onnx.seed = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],i1>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_elu_default +func.func @test_elu_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.elu %arg0, %float0.000000e00, %float1.000000e00, %float1.000000e00 : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Elu"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_elu_example +func.func @test_elu_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.elu %arg0, %float2.000000e00, %float1.000000e00, %float1.000000e00 : !torch.vtensor<[3],f32>, !torch.float, !torch.float, !torch.float -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Elu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32} : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: @test_depthtospace_example +func.func @test_depthtospace_example(%arg0: !torch.vtensor<[1,8,2,3],f32>) -> !torch.vtensor<[1,2,4,6],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[SIZE_2:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C4:.*]] = torch.constant.int 4 + // CHECK: %[[DIV:.*]] = torch.aten.div.int %[[SIZE_0]], %[[C4]] : !torch.int, !torch.int -> !torch.float + // CHECK: %[[INT:.*]] = torch.aten.Int.float %[[DIV]] : !torch.float -> !torch.int + // CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[SIZE]], %[[C2_0]], %[[C2_0]], %[[INT]], %[[SIZE_1]], %[[SIZE_2]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[1,8,2,3],f32>, !torch.list -> !torch.vtensor<[1,2,2,2,2,3],f32> + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C3_0:.*]] = torch.constant.int 3 + // CHECK: %[[TRANSPOSE:.*]] = torch.aten.transpose.int %[[RESHAPE]], %[[C1_0]], %[[C3_0]] : !torch.vtensor<[1,2,2,2,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,2,3],f32> + // CHECK: %[[C2_1:.*]] = torch.constant.int 2 + // CHECK: %[[C4_0:.*]] = torch.constant.int 4 + // CHECK: %[[TRANSPOSE_1:.*]] = torch.aten.transpose.int %[[TRANSPOSE]], %[[C2_1]], %[[C4_0]] : !torch.vtensor<[1,2,2,2,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,2,3],f32> + // CHECK: %[[C4_1:.*]] = torch.constant.int 4 + // CHECK: %[[C5:.*]] = torch.constant.int 5 + // CHECK: %[[TRANSPOSE_2:.*]] = torch.aten.transpose.int %[[TRANSPOSE_1]], %[[C4_1]], %[[C5]] : !torch.vtensor<[1,2,2,2,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,3,2],f32> + // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[SIZE_1]], %[[C2_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[SIZE_2]], %[[C2_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[SIZE]], %5, %[[MUL]], %[[MUL_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[TRANSPOSE_2]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[1,2,2,2,3,2],f32>, !torch.list -> !torch.vtensor<[1,2,4,6],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[1,2,4,6],f32 + %0 = torch.operator "onnx.DepthToSpace"(%arg0) {torch.onnx.blocksize = 2 : si64, torch.onnx.mode = "DCR"} : (!torch.vtensor<[1,8,2,3],f32>) -> !torch.vtensor<[1,2,4,6],f32> + return %0 : !torch.vtensor<[1,2,4,6],f32> +} + +// ----- + +// CHECK-LABEL: @test_depthtospace_crd_mode_example +func.func @test_depthtospace_crd_mode_example(%arg0: !torch.vtensor<[1,8,2,3],f32>) -> !torch.vtensor<[1,2,4,6],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[SIZE_2:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C4:.*]] = torch.constant.int 4 + // CHECK: %[[DIV:.*]] = torch.aten.div.int %[[SIZE_0]], %[[C4]] : !torch.int, !torch.int -> !torch.float + // CHECK: %[[INT:.*]] = torch.aten.Int.float %[[DIV]] : !torch.float -> !torch.int + // CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[SIZE]], %[[C2_0]], %[[C2_0]], %[[INT]], %[[SIZE_1]], %[[SIZE_2]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[1,8,2,3],f32>, !torch.list -> !torch.vtensor<[1,2,2,2,2,3],f32> + // CHECK: %[[C2_1:.*]] = torch.constant.int 2 + // CHECK: %[[C4_0:.*]] = torch.constant.int 4 + // CHECK: %[[TRANSPOSE:.*]] = torch.aten.transpose.int %[[RESHAPE]], %[[C2_1]], %[[C4_0]] : !torch.vtensor<[1,2,2,2,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,2,3],f32> + // CHECK: %[[C3_0:.*]] = torch.constant.int 3 + // CHECK: %[[C4_1:.*]] = torch.constant.int 4 + // CHECK: %[[TRANSPOSE_1:.*]] = torch.aten.transpose.int %[[TRANSPOSE]], %[[C3_0]], %[[C4_1]] : !torch.vtensor<[1,2,2,2,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,2,3],f32> + // CHECK: %[[C4_1:.*]] = torch.constant.int 4 + // CHECK: %[[C5:.*]] = torch.constant.int 5 + // CHECK: %[[TRANSPOSE_2:.*]] = torch.aten.transpose.int %[[TRANSPOSE_1]], %[[C4_1]], %[[C5]] : !torch.vtensor<[1,2,2,2,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,3,2],f32> + // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[SIZE_1]], %[[C2_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[SIZE_2]], %[[C2_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[SIZE]], %5, %[[MUL]], %[[MUL_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[TRANSPOSE_2]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[1,2,2,2,3,2],f32>, !torch.list -> !torch.vtensor<[1,2,4,6],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[1,2,4,6],f32 + %0 = torch.operator "onnx.DepthToSpace"(%arg0) {torch.onnx.blocksize = 2 : si64, torch.onnx.mode = "CRD"} : (!torch.vtensor<[1,8,2,3],f32>) -> !torch.vtensor<[1,2,4,6],f32> + return %0 : !torch.vtensor<[1,2,4,6],f32> +} + +// ----- + +// CHECK-LABEL: @float_constant +func.func @float_constant() -> !torch.vtensor<[], f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { + // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<2.500000e-01> : tensor) : !torch.vtensor<[],f32> + // CHECK: return %[[CST]] + %0 = torch.operator "onnx.Constant"() {torch.onnx.value_float = 0.25 : f32} : () -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} + +// ----- + +// CHECK-LABEL: @int_constant +func.func @int_constant() -> !torch.vtensor<[], si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { + // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<79> : tensor) : !torch.vtensor<[],si64> + // CHECK: return %[[CST]] + %0 = torch.operator "onnx.Constant"() {torch.onnx.value_int = 79 : si64} : () -> !torch.vtensor<[],si64> + return %0 : !torch.vtensor<[],si64> +} + +// ----- + +// CHECK-LABEL: @dense_constant +func.func @dense_constant() -> !torch.vtensor<[1], si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { + // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<13> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + // CHECK: return %[[CST]] + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<13> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> + return %0 : !torch.vtensor<[1],si64> +} + +// ----- + +// CHECK-LABEL: @ints_constant +func.func @ints_constant() -> !torch.vtensor<[2], si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { + // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<[7, 9]> : tensor<2xsi64>) : !torch.vtensor<[2],si64> + // CHECK: return %[[CST]] + %0 = "torch.operator"() <{name = "onnx.Constant"}> {torch.onnx.value_ints = [7 : si64, 9 : si64]} : () -> !torch.vtensor<[2],si64> + return %0 : !torch.vtensor<[2],si64> +} + +// ----- + +// CHECK-LABEL: @dense_constant +func.func @dense_constant() -> () attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { + // CHECK: torch.vtensor.literal(dense<[0, 10, 128, 17000]> : tensor<4xsi32>) : !torch.vtensor<[4],si32> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_int32> : tensor<4xsi32>} : () -> !torch.vtensor<[4],si32> + // CHECK: torch.vtensor.literal(dense<[0.000000e+00, 1.000000e+01, 1.280000e+02, 1.700000e+04]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_float32> : tensor<4xf32>} : () -> !torch.vtensor<[4],f32> + // CHECK: torch.vtensor.literal(dense<[-128, -1, 50, 127]> : tensor<4xsi8>) : !torch.vtensor<[4],si8> + %2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_int8> : tensor<4xsi8>} : () -> !torch.vtensor<[4],si8> + // CHECK: torch.vtensor.literal(dense<[128, 255, 50, 127]> : tensor<4xui8>) : !torch.vtensor<[4],ui8> + %3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_int8> : tensor<4xui8>} : () -> !torch.vtensor<[4],ui8> + return +} + +{-# + dialect_resources: { + builtin: { + _int8: "0x0800000080FF327F", + _int32: "0x08000000000000000a0000008000000068420000", + _float32: "0x0800000000000000000020410000004300d08446" + } + } +#-} + +// ----- + +// CHECK-LABEL: @dense_constant_i1 +func.func @dense_constant_i1() -> !torch.vtensor<[5],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64} { + // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<[true, false, false, true, true]> : tensor<5xi1>) : !torch.vtensor<[5],i1> + // CHECK: return %[[CST]] : !torch.vtensor<[5],i1> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_> : tensor<5xi1>} : () -> !torch.vtensor<[5],i1> + return %0 : !torch.vtensor<[5],i1> +} + +{-# + dialect_resources: { + builtin: { + _: "0x080000000100000101" + } + } +#-} + +// ----- + + +// CHECK-LABEL: @test_flatten_4d_axis_2 +func.func @test_flatten_4d_axis_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3,20],f32> + // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1 + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,3,20],f32>, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> + return %0 : !torch.vtensor<[6,20],f32> +} + +// ----- + +// // CHECK-LABEL: @test_flatten_4d_axis_0 +func.func @test_flatten_4d_axis_0(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[120],f32> + // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0 + // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor<[120],f32>, !torch.int -> !torch.vtensor<[1,120],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> + return %0 : !torch.vtensor<[1,120],f32> +} + +// ----- + +// CHECK-LABEL: @test_flatten_4d_axis_4 +func.func @test_flatten_4d_axis_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[120,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 4 + // CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int -> !torch.vtensor<[2,3,4,5,1],f32> + // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 3 + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,3,4,5,1],f32>, !torch.int, !torch.int -> !torch.vtensor<[120,1],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 4 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[120,1],f32> + return %0 : !torch.vtensor<[120,1],f32> +} + +// ----- + +// CHECK-LABEL: @test_flatten_4d_axis_negative_2 +func.func @test_flatten_4d_axis_negative_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3,20],f32> + // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1 + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,3,20],f32>, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> + return %0 : !torch.vtensor<[6,20],f32> +} + +// ----- + +// CHECK-LABEL: @test_flatten_4d_axis_negative_1 +func.func @test_flatten_4d_axis_negative_1(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[24,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3,4,5],f32> + // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 2 + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[24,5],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[24,5],f32> + return %0 : !torch.vtensor<[24,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_flatten_4d_axis_negative_4 +func.func @test_flatten_4d_axis_negative_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[120],f32> + // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0 + // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor<[120],f32>, !torch.int -> !torch.vtensor<[1,120],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -4 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> + return %0 : !torch.vtensor<[1,120],f32> +} + +// ----- + +// CHECK-LABEL: @test_flatten_2d_axis_1 +func.func @test_flatten_2d_axis_1(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> + // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0 + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> + return %0 : !torch.vtensor<[2,3],f32> +} + +// ----- + +// CHECK-LABEL: @test_flatten_1d_axis_0 +func.func @test_flatten_1d_axis_0(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2],f32> + // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0 + // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[1,2],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> + return %0 : !torch.vtensor<[1,2],f32> +} + +// ----- + +// CHECK-LABEL: @test_flatten_1d_axis_negative_1 +func.func @test_flatten_1d_axis_negative_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2],f32> + // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0 + // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[1,2],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> + return %0 : !torch.vtensor<[1,2],f32> +} + +// ----- + +// COM: CHECK-LABEL: @test_flatten_1d_axis_1 +func.func @test_flatten_1d_axis_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[2,1],f32> + // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0 + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,1],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,1],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32> + return %0 : !torch.vtensor<[2,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_constant_of_shape_dense_float_default +func.func @test_constant_of_shape_dense_float_default() -> !torch.vtensor<[2,3,4], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[SHAPE_CST:.*]] = torch.vtensor.literal(dense<[2, 3, 4]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]], %[[ELE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FILL_VAL:.*]] = torch.constant.float 0.000000e+00 + // CHECK: %[[ATEN_FULL:.*]] = torch.aten.full %[[DIM_LIST]], %[[FILL_VAL]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],f32> + %cst = torch.vtensor.literal(dense<[2,3,4]> : tensor<3xsi64>) : !torch.vtensor<[3], si64> + %0 = "torch.operator"(%cst) <{name = "onnx.ConstantOfShape"}> : (!torch.vtensor<[3], si64>) -> !torch.vtensor<[2,3,4], f32> + return %0 : !torch.vtensor<[2,3,4], f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_constant_of_shape_dense_float_cst +func.func @test_constant_of_shape_dense_float_cst() -> !torch.vtensor<[2,3,4], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[SHAPE_CST:.*]] = torch.vtensor.literal(dense<[2, 3, 4]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]], %[[ELE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FILL_VAL:.*]] = torch.constant.float 3.4000000953674316 + // CHECK: %[[ATEN_FULL:.*]] = torch.aten.full %[[DIM_LIST]], %[[FILL_VAL]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],f32> + %cst = torch.vtensor.literal(dense<[2,3,4]> : tensor<3xsi64>) : !torch.vtensor<[3], si64> + %0 = "torch.operator"(%cst) <{name = "onnx.ConstantOfShape"}> {torch.onnx.value = dense<3.4> : tensor<1xf32>}: (!torch.vtensor<[3], si64>) -> !torch.vtensor<[2,3,4], f32> + return %0 : !torch.vtensor<[2,3,4], f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_constant_of_shape_dense_int_cst +func.func @test_constant_of_shape_dense_int_cst() -> !torch.vtensor<[2,3,4], si64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[SHAPE_CST:.*]] = torch.vtensor.literal(dense<[2, 3, 4]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]], %[[ELE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FILL_VAL:.*]] = torch.constant.int 3 + // CHECK: %[[ATEN_FULL:.*]] = torch.aten.full %[[DIM_LIST]], %[[FILL_VAL]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],si64> + %cst = torch.vtensor.literal(dense<[2,3,4]> : tensor<3xsi64>) : !torch.vtensor<[3], si64> + %0 = "torch.operator"(%cst) <{name = "onnx.ConstantOfShape"}> {torch.onnx.value = dense<3> : tensor<1xsi64>}: (!torch.vtensor<[3], si64>) -> !torch.vtensor<[2,3,4], si64> + return %0 : !torch.vtensor<[2,3,4], si64> +} + +// CHECK-LABEL: func.func @test_celu +func.func @test_celu(%arg0: !torch.vtensor<[3,3,3,1],f32>) -> !torch.vtensor<[3,3,3,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK: %[[ALPHA:.*]] = torch.constant.float 2.000000e+00 +// CHECK: %0 = torch.aten.div.Scalar %arg0, %[[ALPHA]] : !torch.vtensor<[3,3,3,1],f32>, !torch.float -> !torch.vtensor<[3,3,3,1],f32> +// CHECK: %1 = torch.aten.exp %0 : !torch.vtensor<[3,3,3,1],f32> -> !torch.vtensor<[3,3,3,1],f32> +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %2 = torch.aten.sub.Scalar %1, %int1, %int1 : !torch.vtensor<[3,3,3,1],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,3,3,1],f32> +// CHECK: %3 = torch.aten.mul.Scalar %2, %[[ALPHA]] : !torch.vtensor<[3,3,3,1],f32>, !torch.float -> !torch.vtensor<[3,3,3,1],f32> +// CHECK: %int0 = torch.constant.int 0 +// CHECK: %4 = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %none = torch.constant.none +// CHECK: %int6 = torch.constant.int 6 +// CHECK: %[[ZERO:.*]] = torch.aten.full %4, %int0, %int6, %none, %none, %none : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[ZERO]], %3 : !torch.vtensor<[],f32>, !torch.vtensor<[3,3,3,1],f32> -> !torch.vtensor<[3,3,3,1],f32> +// CHECK: %[[MAX:.*]] = torch.aten.maximum %[[ZERO]], %arg0 : !torch.vtensor<[],f32>, !torch.vtensor<[3,3,3,1],f32> -> !torch.vtensor<[3,3,3,1],f32> +// CHECK: %8 = torch.aten.add.Tensor %[[MAX]], %[[MIN]], %int1 : !torch.vtensor<[3,3,3,1],f32>, !torch.vtensor<[3,3,3,1],f32>, !torch.int -> !torch.vtensor<[3,3,3,1],f32> + %0 = torch.operator "onnx.Celu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32} : (!torch.vtensor<[3,3,3,1],f32>) -> !torch.vtensor<[3,3,3,1],f32> + return %0 : !torch.vtensor<[3,3,3,1],f32> +} diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir new file mode 100644 index 000000000000..9dceff316eaa --- /dev/null +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -0,0 +1,744 @@ +// RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch --split-input-file | FileCheck %s +// Generally, the test cases accumulated here come from running the importer +// over all included backend tests that involve simple ops with no model +// level constants. This is a pragmatic choice which lets us have a lot +// of tests in this file, whereas the others tend to be more bespoke. + +// CHECK-LABEL: func.func @test_greater +func.func @test_greater(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> + // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.gt.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1> + %0 = torch.operator "onnx.Greater"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> + return %0 : !torch.vtensor<[3,4,5],i1> +} + +// ----- + +// CHECK-LABEL: func.func @test_greater_or_equal +func.func @test_greater_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> + // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.ge.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1> + %0 = torch.operator "onnx.GreaterOrEqual"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> + return %0 : !torch.vtensor<[3,4,5],i1> +} + +// ----- + +// CHECK-LABEL: func.func @test_less +func.func @test_less(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> + // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.lt.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1> + %0 = torch.operator "onnx.Less"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> + return %0 : !torch.vtensor<[3,4,5],i1> +} + +// ----- + +// CHECK-LABEL: func.func @test_gather_nd +func.func @test_gather_nd(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} { + // CHECK: %[[AXIS:.+]] = torch.constant.int 0 + // CHECK: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK: %[[ONE:.+]] = torch.constant.int 1 + // CHECK: %[[LT:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] + // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]] + // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]] + // CHECK: %[[SEL:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg1 + // CHECK: %[[D0:.+]] = torch.constant.int 0 + // CHECK: %[[SZ0:.+]] = torch.aten.size.int %[[SEL]], %[[D0]] + // CHECK: %[[D1:.+]] = torch.constant.int 1 + // CHECK: %[[SZ1:.+]] = torch.aten.size.int %[[SEL]], %[[D1]] + // CHECK: %[[D2:.+]] = torch.constant.int 2 + // CHECK: %[[SZ2:.+]] = torch.aten.size.int %[[SEL]], %[[D2]] + // CHECK: %[[D3:.+]] = torch.constant.int 3 + // CHECK: %[[SZ3:.+]] = torch.aten.size.int %[[SEL]], %[[D3]] + // CHECK: %[[SZ:.+]] = torch.prim.ListConstruct %[[SZ0]], %[[SZ1]], %[[SZ2]], %[[SZ3]] + // CHECK: %[[DIM:.+]] = torch.aten.dim %[[SEL]] + // CHECK: %[[SUB:.+]] = torch.aten.sub.int %[[DIM]], %[[ONE]] + // CHECK: %[[FLAT:.+]] = torch.aten.flatten.using_ints %[[SEL]], %[[ZERO]], %[[SUB]] + // CHECK: %[[ISEL:.+]] = torch.aten.index_select %arg0, %[[AXIS]], %[[FLAT]] + // CHECK: %[[RES:.+]] = torch.aten.unflatten.int %[[ISEL]], %[[AXIS]], %[[SZ]] + // CHECK: return %[[RES]] + %0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32> + return %0 : !torch.vtensor<[8,10,20,40,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_gather_scalar +func.func @test_gather_scalar(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[], si64>) -> !torch.vtensor<[4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} { + // CHECK: %[[AXIS:.+]] = torch.constant.int 0 + // CHECK: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK: %[[ONE:.+]] = torch.constant.int 1 + // CHECK: %[[LT:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] + // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]] + // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]] + // CHECK: %[[SEL:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg1 + // CHECK: %[[FLAT:.+]] = torch.aten.unsqueeze %[[SEL]], %[[ZERO]] : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ISEL:.+]] = torch.aten.index_select %arg0, %[[AXIS]], %[[FLAT]] + // CHECK: %[[RES:.+]] = torch.aten.squeeze %[[ISEL]] : !torch.vtensor<[1,4,5],f32> -> !torch.vtensor<[4,5],f32> + // CHECK: return %[[RES]] + %0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[], si64>) -> !torch.vtensor<[4,5],f32> + return %0 : !torch.vtensor<[4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_gather_elements +func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} { + // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[GATHER:.+]] = torch.aten.gather %arg0, %[[INT0]], %arg1, %[[FALSE]] + %0 = torch.operator "onnx.GatherElements"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_gemm_defaultA +func.func @test_gemm_defaultA(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK-DAG: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Gemm"(%arg0, %arg1) : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_gemm_defaultB +func.func @test_gemm_defaultB(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK: %[[I1:.+]] = torch.constant.int 1 + // CHECK: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32> + // CHECK: torch.aten.add.Tensor %[[MM]], %arg2, %[[I1]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_gemm_transposeA +func.func @test_gemm_transposeA(%arg0: !torch.vtensor<[5,3],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK: %[[I0:.+]] = torch.constant.int 0 + // CHECK: %[[I1:.+]] = torch.constant.int 1 + // CHECK: %[[TRANS:.+]] = torch.aten.transpose.int %arg0, %[[I0]], %[[I1]] : !torch.vtensor<[5,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,5],f32> + // CHECK: %[[MM:.+]] = torch.aten.mm %[[TRANS]], %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32> + // CHECK: torch.aten.add.Tensor %[[MM]], %arg2, %[[I1]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) {torch.onnx.transA = 1 : si64} : (!torch.vtensor<[5,3],f32>, !torch.vtensor<[5,4],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_gemm_transposeB +func.func @test_gemm_transposeB(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[4,5],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK: %[[I0:.+]] = torch.constant.int 0 + // CHECK: %[[I1:.+]] = torch.constant.int 1 + // CHECK: %[[TRANS:.+]] = torch.aten.transpose.int %arg1, %[[I0]], %[[I1]] : !torch.vtensor<[4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[5,4],f32> + // CHECK: %[[MM:.+]] = torch.aten.mm %arg0, %[[TRANS]] : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32> + // CHECK: torch.aten.add.Tensor %[[MM]], %arg2, %[[I1]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) {torch.onnx.transB = 1 : si64} : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[4,5],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_gemm_alpha +func.func @test_gemm_alpha(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK-DAG: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32> + // CHECK-DAG: %[[ALPHA:.+]] = torch.constant.float 5.000000e-01 + // CHECK: torch.aten.add.Tensor %arg2, %[[MM]], %[[ALPHA]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[3,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) {torch.onnx.alpha = 5.000000e-01 : f32} : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_gemm_beta +func.func @test_gemm_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK-DAG: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32> + // CHECK-DAG: %[[BETA:.+]] = torch.constant.float 5.000000e-01 + // CHECK: torch.aten.add.Tensor %[[MM]], %arg2, %[[BETA]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) {torch.onnx.beta = 5.000000e-01 : f32} : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_gemm_alpha_beta +func.func @test_gemm_alpha_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[I1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32> + // CHECK-DAG: %[[ALPHA:.+]] = torch.constant.float 5.000000e-01 + // CHECK-DAG: %[[BETA:.+]] = torch.constant.float 2.500000e-01 + // CHECK-DAG: %[[MUL:.+]] = torch.aten.mul.Scalar %[[MM]], %[[ALPHA]] : !torch.vtensor<[3,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32> + // CHECK: torch.aten.add.Tensor %[[MUL]], %arg2, %[[BETA]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) {torch.onnx.alpha = 5.000000e-01 : f32, torch.onnx.beta = 2.500000e-01 : f32} : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL : func.func @test_layer_norm +func.func @test_layer_norm(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[3,4],f32>, %arg2: !torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4], f32>, !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) + attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %int3 = torch.constant.int 3 + // CHECK: %int4 = torch.constant.int 4 + // CHECK: %0 = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list + // CHECK: %result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %0, %arg1, %arg2 + %0:3 = torch.operator "onnx.LayerNormalization"(%arg0, %arg1, %arg2) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) + return %0#0, %0#1, %0#2 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32> +} + +// ----- + +// CHECK-LABEL : func.func @test_layer_norm_single_result +func.func @test_layer_norm_single_result(%arg0: !torch.vtensor<[1,4,768],f32>, %arg1: !torch.vtensor<[768],f32>, %arg2: !torch.vtensor<[768],f32>) -> (!torch.vtensor<[1,4,768], f32>) + attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %float9.999990e-06 = torch.constant.float 9.9999997473787516E-6 + // CHECK: %int768 = torch.constant.int 768 + // CHECK: %0 = torch.prim.ListConstruct %int768 : (!torch.int) -> !torch.list + // CHECK: %result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %0, %arg1, %arg2 + %0 = torch.operator "onnx.LayerNormalization"(%arg0, %arg1, %arg2) {torch.onnx.axis = -1 : si64, torch.onnx.epsilon = 9.99999974E-6 : f32} : (!torch.vtensor<[1,4,768],f32>, !torch.vtensor<[768],f32>, !torch.vtensor<[768],f32>) -> !torch.vtensor<[1,4,768],f32> + return %0 : !torch.vtensor<[1,4,768],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_leaky_relu +func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 16 : si64} { + // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2 + // CHECK: %[[LRELU:.+]] = torch.aten.leaky_relu %arg0, %[[F2]] + %0 = torch.operator "onnx.LeakyRelu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_matmul_2d +func.func @test_matmul_2d(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3,3],f32> + %0 = torch.operator "onnx.MatMul"(%arg0, %arg1) : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,3],f32> + return %0 : !torch.vtensor<[3,3],f32> +} + +// ----- + +// CHECK-LABEL: @test_matmul_3d +func.func @test_matmul_3d(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,4,3],f32>) -> !torch.vtensor<[2,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,4,3],f32> -> !torch.vtensor<[2,3,3],f32> + %0 = torch.operator "onnx.MatMul"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,4,3],f32>) -> !torch.vtensor<[2,3,3],f32> + return %0 : !torch.vtensor<[2,3,3],f32> +} + +// ----- + +// CHECK-LABEL: @test_matmul_4d +func.func @test_matmul_4d(%arg0: !torch.vtensor<[1,2,3,4],f32>, %arg1: !torch.vtensor<[1,2,4,3],f32>) -> !torch.vtensor<[1,2,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[1,2,3,4],f32>, !torch.vtensor<[1,2,4,3],f32> -> !torch.vtensor<[1,2,3,3],f32> + %0 = torch.operator "onnx.MatMul"(%arg0, %arg1) : (!torch.vtensor<[1,2,3,4],f32>, !torch.vtensor<[1,2,4,3],f32>) -> !torch.vtensor<[1,2,3,3],f32> + return %0 : !torch.vtensor<[1,2,3,3],f32> +} + +// ----- + +// CHECK-LABEL: @test_matmulinteger +func.func @test_matmulinteger(%arg0: !torch.vtensor<[4,3],ui8>, %arg1: !torch.vtensor<[3,2],ui8>, %arg2: !torch.vtensor<[1],ui8>, %arg3: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[4,2],si32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[4,3],ui8>, !torch.vtensor<[3,2],ui8>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[4,2],si32> + // CHECK: %[[LITEM:.+]] = torch.aten.item %arg2 + // CHECK: %[[RITEM:.+]] = torch.aten.item %arg3 + // CHECK: %[[SCALE:.+]] = torch.constant.float 1.000000e+00 + // CHECK: %[[LMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[LITEM]] : !torch.vtensor<[4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[4,3],!torch.quint8> + // CHECK: %[[RMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[SCALE]], %[[RITEM]] : !torch.vtensor<[3,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[3,2],!torch.quint8> + // CHECK: %[[MM:.+]] = torch.aten.mm %[[LMAKE]], %[[RMAKE]] + // CHECK: return %[[MM]] + return %0 : !torch.vtensor<[4,2],si32> +} + +// ----- + +// CHECK-LABEL: func.func @test_mul + func.func @test_mul(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Mul"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> + } + +// ----- + +// CHECK-LABEL: func.func @test_maxpool_2d_default +func.func @test_maxpool_2d_default(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[I2_1:.*]] = torch.constant.int 2 + // CHECK: %[[LIST22:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[I0_0:.*]] = torch.constant.int 0 + // CHECK: %[[I0_1:.*]] = torch.constant.int 0 + // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[I1_0:.*]] = torch.constant.int 1 + // CHECK: %[[I1_1:.*]] = torch.constant.int 1 + // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1_0]], %[[I1_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[I1_2:.*]] = torch.constant.int 1 + // CHECK: %[[I1_3:.*]] = torch.constant.int 1 + // CHECK: %[[LIST3:.*]] = torch.prim.ListConstruct %[[I1_2]], %[[I1_3]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.max_pool2d %arg0, %[[LIST22]], %[[LIST1]], %[[LIST0]], %[[LIST3]], %[[FALSE]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,3,31,31],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32> + return %0 : !torch.vtensor<[1,3,31,31],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_maxpool_2d_ceil +func.func @test_maxpool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { + // CHECK: %[[I3:.*]] = torch.constant.int 3 + // CHECK: %[[I3_1:.*]] = torch.constant.int 3 + // CHECK: %[[LIST33:.*]] = torch.prim.ListConstruct %[[I3]], %[[I3_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[I0_0:.*]] = torch.constant.int 0 + // CHECK: %[[I0_1:.*]] = torch.constant.int 0 + // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[I2_1:.*]] = torch.constant.int 2 + // CHECK: %[[LIST22:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[I1_0:.*]] = torch.constant.int 1 + // CHECK: %[[I1_1:.*]] = torch.constant.int 1 + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[I1_0]], %[[I1_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRUE:.*]] = torch.constant.bool true + // CHECK: torch.aten.max_pool2d %arg0, %[[LIST33]], %[[LIST22]], %[[LIST0]], %[[LIST]], %[[TRUE]] : !torch.vtensor<[1,1,4,4],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,1,2,2],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 1 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32> + return %0 : !torch.vtensor<[1,1,2,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_maxpool_3d_default +func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[I2_1:.*]] = torch.constant.int 2 + // CHECK: %[[I2_2:.*]] = torch.constant.int 2 + // CHECK: %[[LIST222:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]], %[[I2_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[I0_0:.*]] = torch.constant.int 0 + // CHECK: %[[I0_1:.*]] = torch.constant.int 0 + // CHECK: %[[I0_2:.*]] = torch.constant.int 0 + // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]], %[[I0_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[I1_0:.*]] = torch.constant.int 1 + // CHECK: %[[I1_1:.*]] = torch.constant.int 1 + // CHECK: %[[I1_2:.*]] = torch.constant.int 1 + // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1_0]], %[[I1_1]], %[[I1_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[I1_3:.*]] = torch.constant.int 1 + // CHECK: %[[I1_4:.*]] = torch.constant.int 1 + // CHECK: %[[I1_5:.*]] = torch.constant.int 1 + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[I1_3]], %[[I1_4]], %[[I1_5]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.max_pool3d %arg0, %[[LIST222]], %[[LIST1]], %[[LIST0]], %[[LIST]], %[[FALSE]] : !torch.vtensor<[1,3,32,32,32],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,3,31,31,31],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32> + return %0 : !torch.vtensor<[1,3,31,31,31],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_maxpool_pad +func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,111,111],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { + // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 + // CHECK: %[[INT1_1:.+]] = torch.constant.int 1 + // CHECK: %[[INT2_0:.+]] = torch.constant.int 2 + // CHECK: %[[INT2_1:.+]] = torch.constant.int 2 + // CHECK: %[[PADI:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]], %[[INT2_0]], %[[INT2_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[MIN:.+]] = torch.constant.float -1.7976931348623157E+308 + // CHECK: %[[PADDED:.+]] = torch.aten.constant_pad_nd %arg0, %[[PADI]], %[[MIN]] : !torch.vtensor<[1,64,111,111],f32>, !torch.list, !torch.float -> !torch.vtensor<[1,64,114,114],f32> + // CHECK: %[[INT3:.*]] = torch.constant.int 3 + // CHECK: %[[INT3_0:.*]] = torch.constant.int 3 + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[INT2_4:.*]] = torch.constant.int 2 + // CHECK: %[[LIST3:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[INT1_0:.*]] = torch.constant.int 1 + // CHECK: %[[INT1_1:.*]] = torch.constant.int 1 + // CHECK: %[[EMPTY_LIST:.*]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[OUT:.*]] = torch.aten.max_pool2d %[[PADDED]], %[[LIST]], %[[LIST3]], %[[LIST2]], %[[EMPTY_LIST]], %[[FALSE]] : !torch.vtensor<[1,64,114,114],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,64,56,56],f32> + // CHECK: return %[[OUT]] : !torch.vtensor<[1,64,56,56],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 0 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,64,111,111],f32>) -> !torch.vtensor<[1,64,56,56],f32> + return %0 : !torch.vtensor<[1,64,56,56],f32> +} + + +// ----- + +// CHECK-LABEL: func.func @test_maxpool_symmetric_pad +func.func @test_maxpool_symmetric_pad(%arg0: !torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { + // CHECK: %[[INT3:.*]] = torch.constant.int 3 + // CHECK: %[[INT3_0:.*]] = torch.constant.int 3 + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[INT1_1:.*]] = torch.constant.int 1 + // CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[INT2_4:.*]] = torch.constant.int 2 + // CHECK: %[[LIST3:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[INT1_2:.*]] = torch.constant.int 1 + // CHECK: %[[INT1_3:.*]] = torch.constant.int 1 + // CHECK: %[[DILATION:.*]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[OUT:.*]] = torch.aten.max_pool2d %arg0, %[[LIST]], %[[LIST3]], %[[LIST2]], %[[DILATION]], %[[FALSE]] : !torch.vtensor<[1,64,112,112],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,64,56,56],f32> + // CHECK: return %[[OUT]] : !torch.vtensor<[1,64,56,56],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 0 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32> + return %0 : !torch.vtensor<[1,64,56,56],f32> +} + +// ----- + +// CHECK-LABEL: @test_gelu_default_1 +func.func @test_gelu_default_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[STR1:.*]] = torch.constant.str "none" + // CHECK: torch.aten.gelu %arg0, %[[STR1]] : !torch.vtensor<[3],f32>, !torch.str -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Gelu"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: @test_gelu_default_2 +func.func @test_gelu_default_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[STR1:.*]] = torch.constant.str "none" + // CHECK: torch.aten.gelu %arg0, %[[STR1]] : !torch.vtensor<[3,4,5],f32>, !torch.str -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Gelu"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_gelu_tanh_1 +func.func @test_gelu_tanh_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[STR1:.*]] = torch.constant.str "tanh" + // CHECK: torch.aten.gelu %arg0, %[[STR1]] : !torch.vtensor<[3],f32>, !torch.str -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Gelu"(%arg0) {torch.onnx.approximate = "tanh"} : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: @test_gelu_tanh_2 +func.func @test_gelu_tanh_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[STR1:.*]] = torch.constant.str "tanh" + // CHECK: torch.aten.gelu %arg0, %[[STR1]] : !torch.vtensor<[3,4,5],f32>, !torch.str -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Gelu"(%arg0) {torch.onnx.approximate = "tanh"} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_grid_sampler +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[INT0_0:.*]] = torch.constant.int 0 +// CHECK: %[[B0:.*]] = torch.constant.bool false +// CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT0]], %[[INT0_0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32> +func.func @test_grid_sampler(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %4 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 1 : si64, torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> + return %4 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_less_or_equal +func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> + // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.le.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1> + %0 = torch.operator "onnx.LessOrEqual"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> + return %0 : !torch.vtensor<[3,4,5],i1> +} + +// ----- + +// CHECK-LABEL: func.func @test_pad +func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>, %arg2: !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} { + // CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT_0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM_0:.+]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[SELECT_1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM_1:.+]] = torch.aten.item %[[SELECT_1]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[SELECT_2:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM_2:.+]] = torch.aten.item %[[SELECT_2]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_0]], %[[ITEM_2]], %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[STR:.+]] = torch.constant.str "constant" + // CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32> + // CHECK: return %[[PAD]] : !torch.vtensor<[5,4],f32> + %0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32> + return %0 : !torch.vtensor<[5,4],f32> +} + +// ----- + +// CHECK-LABEL: @test_pad_optional_constant +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32> +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64> +// CHECK: %[[VAL:.+]] = torch.constant.float 0 +// CHECK: %[[CONST_STR:.*]] = torch.constant.str "constant" +// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[CONST_STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32> + +func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> + return %0 : !torch.vtensor<[5,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_pow + func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Pow"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> + } + +// ----- + +// CHECK-LABEL: @test_hardsigmoid_example +func.func @test_hardsigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01 + // CHECK: %[[BETA_FLOAT:.*]] = torch.constant.float 0.60000002384185791 + // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %arg0, %[[BETA_FLOAT:.*]], %[[ALPHA_FLOAT:.*]] : !torch.vtensor<[3],f32>, !torch.float, !torch.float -> !torch.vtensor<[3],f32> + // CHECK: %[[INT_1:.*]] = torch.constant.int 1 + // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none + // CHECK: %[[INT_TYPE_FOR_TENSOR_ONE:.*]] = torch.constant.int 6 + // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[INT_1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[MIN_EXPRESSION:.*]] = torch.aten.minimum %[[ONE_TENSOR:.*]], %[[ALPHA_MULTI_X_PLUS_BETA:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[INT_0:.*]] = torch.constant.int 0 + // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE_FOR_ZERO:.*]] = torch.constant.none + // CHECK: %[[INT_TYPE_FOR_TENSOR_ZERO:.*]] = torch.constant.int 6 + // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[INT_0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[RESULT:.*]] = torch.aten.maximum %[[ZERO_TENSOR:.*]], %[[MIN_EXPRESSION:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: return %[[RESULT:.*]] : !torch.vtensor<[3],f32> + + %0 = torch.operator "onnx.HardSigmoid"(%arg0) {torch.onnx.alpha = 5.000000e-01 : f32, torch.onnx.beta = 6.000000e-01 : f32} : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: @test_hardsigmoid +func.func @test_hardsigmoid(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01 + // CHECK: %[[BETA_FLOAT:.*]] = torch.constant.float 0.60000002384185791 + // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %arg0, %[[BETA_FLOAT:.*]], %[[ALPHA_FLOAT:.*]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[INT_1:.*]] = torch.constant.int 1 + // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none + // CHECK: %[[INT_TYPE_FOR_TENSOR_ONE:.*]] = torch.constant.int 6 + // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[INT_1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[MIN_EXPRESSION:.*]] = torch.aten.minimum %[[ONE_TENSOR:.*]], %[[ALPHA_MULTI_X_PLUS_BETA:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[INT_0:.*]] = torch.constant.int 0 + // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE_FOR_ZERO:.*]] = torch.constant.none + // CHECK: %[[INT_TYPE_FOR_TENSOR_ZERO:.*]] = torch.constant.int 6 + // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[INT_0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[RESULT:.*]] = torch.aten.maximum %[[ZERO_TENSOR:.*]], %[[MIN_EXPRESSION:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: return %[[RESULT:.*]] : !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.HardSigmoid"(%arg0) {torch.onnx.alpha = 5.000000e-01 : f32, torch.onnx.beta = 6.000000e-01 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_hardsigmoid_default +func.func @test_hardsigmoid_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 0.20000000298023224 + // CHECK: %[[BETA_FLOAT:.*]] = torch.constant.float 5.000000e-01 + // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %arg0, %[[BETA_FLOAT:.*]], %[[ALPHA_FLOAT:.*]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[INT_1:.*]] = torch.constant.int 1 + // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none + // CHECK: %[[INT_TYPE_FOR_TENSOR_ONE:.*]] = torch.constant.int 6 + // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[INT_1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[MIN_EXPRESSION:.*]] = torch.aten.minimum %[[ONE_TENSOR:.*]], %[[ALPHA_MULTI_X_PLUS_BETA:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[INT_0:.*]] = torch.constant.int 0 + // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE_FOR_ZERO:.*]] = torch.constant.none + // CHECK: %[[INT_TYPE_FOR_TENSOR_ZERO:.*]] = torch.constant.int 6 + // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[INT_0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: torch.aten.maximum %[[ZERO_TENSOR:.*]], %[[MIN_EXPRESSION:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.HardSigmoid"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_globalaveragepool +func.func @test_globalaveragepool(%arg0: !torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C5:.*]] = torch.constant.int 5 + // CHECK: %[[C5_0:.*]] = torch.constant.int 5 + // CHECK: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[C5]], %[[C5_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.avg_pool2d %arg0, %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,3,5,5],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,1,1],f32> + %0 = torch.operator "onnx.GlobalAveragePool"(%arg0) : (!torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32> + return %0 : !torch.vtensor<[1,3,1,1],f32> +} + +// ----- + +// CHECK-LABEL: @test_globalaveragepool_precomputed +func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[C3_0:.*]] = torch.constant.int 3 + // CHECK: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[C3]], %[[C3_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.avg_pool2d %arg0, %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,1,3,3],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1,1],f32> + %0 = torch.operator "onnx.GlobalAveragePool"(%arg0) : (!torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,1,1],f32> + return %0 : !torch.vtensor<[1,1,1,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_max_example + func.func @test_max_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Max"(%arg0, %arg1, %arg2) : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> + } + +// ----- + +// CHECK-LABEL: func.func @test_min_example + func.func @test_min_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.minimum %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Min"(%arg0, %arg1, %arg2) : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> + } + +// ----- + +// CHECK-LABEL: func.func @test_log + func.func @test_log(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.log %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Log"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> + } + +// ----- + +// CHECK-LABEL: func.func @test_neg + func.func @test_neg(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.neg %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Neg"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> + } + +// ----- + +// CHECK-LABEL: func.func @test_instancenorm + func.func @test_instancenorm(%arg0: !torch.vtensor<[1,2,1,3],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.instance_norm %arg0, %arg1, %arg2, %none, %none, %true, %float0.000000e00, %float9.999990e-06, %false : !torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.none, !torch.none, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[1,2,1,3],f32> + %0 = torch.operator "onnx.InstanceNormalization"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32> + return %0 : !torch.vtensor<[1,2,1,3],f32> + } + +// ----- + +// CHECK-LABEL: func.func @test_not_2d +func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1> + %0 = torch.operator "onnx.Not"(%arg0) : (!torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> + return %0 : !torch.vtensor<[3,4],i1> + } + +// ----- + +// CHECK-LABEL: func.func @test_nonzero + func.func @test_nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64> + %0 = torch.operator "onnx.NonZero"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> + return %0 : !torch.vtensor<[3,4,5],si64> + } + +// ----- + +// CHECK-LABEL: func.func @test_or2d + func.func @test_or2d(%arg0: !torch.vtensor<[3,4],i1>, %arg1: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_or.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],i1>, !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1> + %0 = torch.operator "onnx.Or"(%arg0, %arg1) : (!torch.vtensor<[3,4],i1>, !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> + return %0 : !torch.vtensor<[3,4],i1> + } + +// CHECK-LABEL: func.func @test_identity + func.func @test_identity(%arg0: !torch.vtensor<[3,4], f32>) -> !torch.vtensor<[3,4], f32> attributes {torch.onnx_meta.ir_version = 14 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %0 = torch.aten.clone %arg0, %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.none -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Identity"(%arg0) : (!torch.vtensor<[3,4], f32>) -> !torch.vtensor<[3,4], f32> + return %0 : !torch.vtensor<[3,4], f32> + } + +// CHECK-LABEL: func.func @test_mean_one_input + func.func @test_mean_one_input(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.Mean"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> + } + +// CHECK-LABEL: func.func @test_mean_two_inputs + func.func @test_mean_two_inputs(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: torch.aten.div.Scalar %0, %int2 : !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Mean"(%arg0, %arg1) : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> + } + +// CHECK-LABEL: func.func @test_isinf_negative + func.func @test_isinf_negative(%arg0: !torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.neg %arg0 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],f32> + // CHECK: torch.aten.relu %0 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],f32> + // CHECK: torch.aten.isinf %1 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],i1> + %0 = torch.operator "onnx.IsInf"(%arg0) {torch.onnx.detect_positive = 0 : si64} : (!torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1> + return %0 : !torch.vtensor<[6],i1> + } + +// CHECK-LABEL: func.func @test_isinf_positive + func.func @test_isinf_positive(%arg0: !torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.relu %arg0 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],f32> + // CHECK: torch.aten.isinf %0 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],i1> + %0 = torch.operator "onnx.IsInf"(%arg0) {torch.onnx.detect_negative = 0 : si64} : (!torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1> + return %0 : !torch.vtensor<[6],i1> + } + +// CHECK-LABEL: func.func @test_isnan + func.func @test_isnan(%arg0: !torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.isnan %arg0 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],i1> + %0 = torch.operator "onnx.IsNaN"(%arg0) : (!torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1> + return %0 : !torch.vtensor<[6],i1> + } + +// CHECK-LABEL: func.func @test_prelu_example + func.func @test_prelu_example(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.prelu %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.PRelu"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> + } + +// CHECK-LABEL: func.func @test_prelu_broadcast + func.func @test_prelu_broadcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.prelu %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.PRelu"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> + } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir new file mode 100644 index 000000000000..508ed55d3337 --- /dev/null +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -0,0 +1,1666 @@ +// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s +// Generally, the test cases accumulated here come from running the importer +// over all included backend tests that involve simple ops with no model +// level constants. This is a pragmatic choice which lets us have a lot +// of tests in this file, whereas the others tend to be more bespoke. + +// CHECK-LABEL: @test_quantizelinear_si8 +func.func @test_quantizelinear_si8(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8> + + // CHECK: %[[C12:.+]] = torch.constant.int 12 + // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si8> -> !torch.int + // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %arg0, %[[SCALE]], %[[ZP]], %[[C12]] + // CHECK: %[[REPR:.+]] = torch.aten.int_repr %[[QUANT]] + // CHECK: return %[[REPR]] + return %0 : !torch.vtensor<[6],si8> +} + +// ----- + +// CHECK-LABEL: @test_quantizelinear_ui8 +func.func @test_quantizelinear_ui8(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],ui8> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],ui8> + // CHECK: %[[C13:.+]] = torch.constant.int 13 + // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %arg0, %[[SCALE]], %[[ZP]], %[[C13]] + // CHECK: %[[REPR:.+]] = torch.aten.int_repr %[[QUANT]] + // CHECK: return %[[REPR]] + return %0 : !torch.vtensor<[6],ui8> +} + +// ----- + +// CHECK-LABEL: @test_quantizelinear_i32 +func.func @test_quantizelinear_i32(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si32>) -> !torch.vtensor<[6],si32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],si32>) -> !torch.vtensor<[6],si32> + // CHECK: %[[C14:.+]] = torch.constant.int 14 + // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si32> -> !torch.int + // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %arg0, %[[SCALE]], %[[ZP]], %[[C14]] + // CHECK: %[[REPR:.+]] = torch.aten.int_repr %[[QUANT]] + // CHECK: return %[[REPR]] + return %0 : !torch.vtensor<[6],si32> +} + +// ----- + +// CHECK-LABEL: @test_qlinearconv_nobias +func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> + // CHECK: %[[aZp:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int + // CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8> + // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8> + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 + // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 + // CHECK: %[[INT1_1:.+]] = torch.constant.int 1 + // CHECK: %[[INT1_2:.+]] = torch.constant.int 1 + // CHECK: %[[INT1_3:.+]] = torch.constant.int 1 + // CHECK: %[[INT0_2:.+]] = torch.constant.int 0 + // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] + // CHECK: %[[KERNEL:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] + // CHECK: %[[DILATION:.+]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]] + // CHECK: %[[STRIDE:.+]] = torch.prim.ListConstruct %[[INT0_2]], %[[INT0_2]] + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT1_5:.+]] = torch.constant.int 1 + // CHECK: %[[CONV:.+]] = torch.aten.convolution %[[A]], %[[B]], %[[NONE]], %[[DILATION]], %[[PAD]], %[[KERNEL]], %[[FALSE]], %[[STRIDE]], %[[INT1_5]] : !torch.vtensor<[1,1,7,7],!torch.quint8>, !torch.vtensor<[1,1,1,1],!torch.quint8>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32> + // CHECK: %[[convScale:.+]] = torch.aten.mul.float %[[aScale]], %[[bScale]] : !torch.float, !torch.float -> !torch.float + // CHECK: %[[INT0_6:.+]] = torch.constant.int 0 + // CHECK: %[[C:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[convScale]], %[[INT0_6]] : !torch.vtensor<[1,1,7,7],!torch.qint32>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32> + // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[C]] : !torch.vtensor<[1,1,7,7],!torch.qint32> -> !torch.vtensor<[1,1,7,7],f32> + // CHECK: %[[INT13:.+]] = torch.constant.int 13 + // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[DEQ]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[1,1,7,7],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8> + // CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,1,7,7],!torch.quint8> -> !torch.vtensor<[1,1,7,7],ui8> + // CHECK: return %[[INT]] : !torch.vtensor<[1,1,7,7],ui8> + return %0 : !torch.vtensor<[1,1,7,7],ui8> +} + +// ----- + +// CHECK-LABEL: @test_qlinearconv_bias +func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>, %arg8 : !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8> + // CHECK: %[[aZp:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int + // CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8> + // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8> + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 + // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 + // CHECK: %[[INT1_1:.+]] = torch.constant.int 1 + // CHECK: %[[INT1_2:.+]] = torch.constant.int 1 + // CHECK: %[[INT1_3:.+]] = torch.constant.int 1 + // CHECK: %[[INT0_2:.+]] = torch.constant.int 0 + // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] + // CHECK: %[[KERNEL:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] + // CHECK: %[[DILATION:.+]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]] + // CHECK: %[[STRIDE:.+]] = torch.prim.ListConstruct %[[INT0_2]], %[[INT0_2]] + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[INT1_5:.+]] = torch.constant.int 1 + // CHECK: %[[CONV:.+]] = torch.aten.convolution %[[A]], %[[B]], %arg8, %[[DILATION]], %[[PAD]], %[[KERNEL]], %[[FALSE]], %[[STRIDE]], %[[INT1_5]] : !torch.vtensor<[1,1,7,7],!torch.quint8>, !torch.vtensor<[1,1,1,1],!torch.quint8>, !torch.vtensor<[7],si32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32> + // CHECK: %[[convScale:.+]] = torch.aten.mul.float %[[aScale]], %[[bScale]] : !torch.float, !torch.float -> !torch.float + // CHECK: %[[INT0_6:.+]] = torch.constant.int 0 + // CHECK: %[[C:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[convScale]], %[[INT0_6]] : !torch.vtensor<[1,1,7,7],!torch.qint32>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32> + // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[C]] : !torch.vtensor<[1,1,7,7],!torch.qint32> -> !torch.vtensor<[1,1,7,7],f32> + // CHECK: %[[INT13:.+]] = torch.constant.int 13 + // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[DEQ]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[1,1,7,7],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8> + // CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,1,7,7],!torch.quint8> -> !torch.vtensor<[1,1,7,7],ui8> + // CHECK: return %[[INT]] : !torch.vtensor<[1,1,7,7],ui8> + return %0 : !torch.vtensor<[1,1,7,7],ui8> +} + +// ----- + +// CHECK-LABEL: @test_qlinearmatmul_2D +func.func @test_qlinearmatmul_2D(%arg0: !torch.vtensor<[2,4],ui8>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],ui8>, %arg3: !torch.vtensor<[4,3],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[1],f32>, %arg7: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,3],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.QLinearMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[2,4],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[4,3],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,3],ui8> + // CHECK-DAG: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK-DAG: %[[RESH0:.+]] = torch.aten.reshape %arg2, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list -> !torch.vtensor<[],ui8> + // CHECK-DAG: %[[RESH1:.+]] = torch.aten.reshape %arg5, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list -> !torch.vtensor<[],ui8> + // CHECK-DAG: %[[RESH2:.+]] = torch.aten.reshape %arg7, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list -> !torch.vtensor<[],ui8> + // CHECK-DAG: %[[RESH3:.+]] = torch.aten.reshape %arg1, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[RESH4:.+]] = torch.aten.reshape %arg4, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[RESH5:.+]] = torch.aten.reshape %arg6, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[AZP:.+]] = torch.aten.item %[[RESH0]] : !torch.vtensor<[],ui8> -> !torch.int + // CHECK-DAG: %[[BZP:.+]] = torch.aten.item %[[RESH1]] : !torch.vtensor<[],ui8> -> !torch.int + // CHECK-DAG: %[[CZP:.+]] = torch.aten.item %[[RESH2]] : !torch.vtensor<[],ui8> -> !torch.int + // CHECK-DAG: %[[ASCALE:.+]] = torch.aten.item %[[RESH3]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK-DAG: %[[BSCALE:.+]] = torch.aten.item %[[RESH4]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK-DAG: %[[CCSCALE:.+]] = torch.aten.item %[[RESH5]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK-DAG: %[[LHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[ASCALE]], %[[AZP]] : !torch.vtensor<[2,4],ui8>, !torch.float, !torch.int -> !torch.vtensor<[2,4],!torch.quint8> + // CHECK-DAG: %[[RHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[BSCALE]], %[[BZP]] : !torch.vtensor<[4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[4,3],!torch.quint8> + // CHECK: %[[MM:.+]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,4],!torch.quint8>, !torch.vtensor<[4,3],!torch.quint8> -> !torch.vtensor<[2,3],si32> + // CHECK: %[[CSCALE:.+]] = torch.aten.mul.float %[[ASCALE]], %[[BSCALE]] : !torch.float, !torch.float -> !torch.float + // CHECK: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK: %[[QC:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[MM]], %[[CSCALE]], %[[ZERO]] : !torch.vtensor<[2,3],si32>, !torch.float, !torch.int -> !torch.vtensor<[2,3],!torch.qint32> + // CHECK: %[[FC:.+]] = torch.aten.dequantize.self %[[QC]] : !torch.vtensor<[2,3],!torch.qint32> -> !torch.vtensor<[2,3],f32> + // CHECK: %[[DTY:.+]] = torch.constant.int 13 + // CHECK: %[[QO:.+]] = torch.aten.quantize_per_tensor %[[FC]], %[[CCSCALE]], %[[CZP]], %[[DTY]] : !torch.vtensor<[2,3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[2,3],!torch.quint8> + // CHECK: %[[OUT:.+]] = torch.aten.int_repr %[[QO]] : !torch.vtensor<[2,3],!torch.quint8> -> !torch.vtensor<[2,3],ui8> + // CHECK: return %[[OUT]] + return %0 : !torch.vtensor<[2,3],ui8> +} + +// ----- + +// CHECK-LABEL: @test_qlinearmatmul_3D +func.func @test_qlinearmatmul_3D(%arg0: !torch.vtensor<[2,2,4],ui8>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],ui8>, %arg3: !torch.vtensor<[2,4,3],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[1],f32>, %arg7: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,2,3],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.QLinearMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[2,2,4],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[2,4,3],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,2,3],ui8> + // CHECK-DAG: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK-DAG: %[[RESH0:.+]] = torch.aten.reshape %arg2, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list -> !torch.vtensor<[],ui8> + // CHECK-DAG: %[[RESH1:.+]] = torch.aten.reshape %arg5, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list -> !torch.vtensor<[],ui8> + // CHECK-DAG: %[[RESH2:.+]] = torch.aten.reshape %arg7, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list -> !torch.vtensor<[],ui8> + // CHECK-DAG: %[[RESH3:.+]] = torch.aten.reshape %arg1, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[RESH4:.+]] = torch.aten.reshape %arg4, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[RESH5:.+]] = torch.aten.reshape %arg6, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[AZP:.+]] = torch.aten.item %[[RESH0]] : !torch.vtensor<[],ui8> -> !torch.int + // CHECK-DAG: %[[BZP:.+]] = torch.aten.item %[[RESH1]] : !torch.vtensor<[],ui8> -> !torch.int + // CHECK-DAG: %[[CZP:.+]] = torch.aten.item %[[RESH2]] : !torch.vtensor<[],ui8> -> !torch.int + // CHECK-DAG: %[[ASCALE:.+]] = torch.aten.item %[[RESH3]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK-DAG: %[[BSCALE:.+]] = torch.aten.item %[[RESH4]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK-DAG: %[[CCSCALE:.+]] = torch.aten.item %[[RESH5]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK-DAG: %[[LHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[ASCALE]], %[[AZP]] : !torch.vtensor<[2,2,4],ui8>, !torch.float, !torch.int -> !torch.vtensor<[2,2,4],!torch.quint8> + // CHECK-DAG: %[[RHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[BSCALE]], %[[BZP]] : !torch.vtensor<[2,4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[2,4,3],!torch.quint8> + // CHECK: %[[MM:.+]] = torch.aten.bmm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,2,4],!torch.quint8>, !torch.vtensor<[2,4,3],!torch.quint8> -> !torch.vtensor<[2,2,3],si32> + // CHECK: %[[CSCALE:.+]] = torch.aten.mul.float %[[ASCALE]], %[[BSCALE]] : !torch.float, !torch.float -> !torch.float + // CHECK: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK: %[[QC:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[MM]], %[[CSCALE]], %[[ZERO]] : !torch.vtensor<[2,2,3],si32>, !torch.float, !torch.int -> !torch.vtensor<[2,2,3],!torch.qint32> + // CHECK: %[[FC:.+]] = torch.aten.dequantize.self %[[QC]] : !torch.vtensor<[2,2,3],!torch.qint32> -> !torch.vtensor<[2,2,3],f32> + // CHECK: %[[DTY:.+]] = torch.constant.int 13 + // CHECK: %[[QO:.+]] = torch.aten.quantize_per_tensor %[[FC]], %[[CCSCALE]], %[[CZP]], %[[DTY]] : !torch.vtensor<[2,2,3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[2,2,3],!torch.quint8> + // CHECK: %[[OUT:.+]] = torch.aten.int_repr %[[QO]] : !torch.vtensor<[2,2,3],!torch.quint8> -> !torch.vtensor<[2,2,3],ui8> + // CHECK: return %[[OUT]] + return %0 : !torch.vtensor<[2,2,3],ui8> +} + +// ----- + +// CHECK-LABEL: func.func @test_reciprocal +func.func @test_reciprocal(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.reciprocal %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Reciprocal"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_relu +func.func @test_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.relu %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Relu"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_round +func.func @test_round(%arg0: !torch.vtensor<[15],f32>) -> !torch.vtensor<[15],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + //CHECK: torch.aten.round %arg0 : !torch.vtensor<[15],f32> -> !torch.vtensor<[15],f32> + %0 = torch.operator "onnx.Round"(%arg0) : (!torch.vtensor<[15],f32>) -> !torch.vtensor<[15],f32> + return %0 : !torch.vtensor<[15],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_scatter_elements_with_axis +func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.scatter.src %arg0, %int1, %arg1, %arg2 : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32> -> !torch.vtensor<[1,5],f32> + %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> + return %0 : !torch.vtensor<[1,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_scatter_elements_with_duplicate_indices +func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[STR:.*]] = torch.constant.str "add" + // CHECK: torch.aten.scatter.reduce %arg0, %int1, %arg1, %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32> + %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> + return %0 : !torch.vtensor<[1,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_scatter_elements_without_axis +func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>, %arg1: !torch.vtensor<[2,3],si64>, %arg2: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[3,3],f32>, !torch.int, !torch.vtensor<[2,3],si64>, !torch.vtensor<[2,3],f32> -> !torch.vtensor<[3,3],f32> + %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) : (!torch.vtensor<[3,3],f32>, !torch.vtensor<[2,3],si64>, !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3,3],f32> + return %0 : !torch.vtensor<[3,3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_scatter_elements_with_reduction_mul +func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[STR:.*]] = torch.constant.str "multiply" + // CHECK: torch.aten.scatter.reduce %arg0, %int1, %arg1, %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32> + %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "mul"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> + return %0 : !torch.vtensor<[1,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sigmoid_example +func.func @test_sigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.sigmoid %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Sigmoid"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sin_example +func.func @test_sin_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.sin %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Sin"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_tanh_example +func.func @test_tanh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.tanh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Tanh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sqrt_example +func.func @test_sqrt_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.sqrt %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Sqrt"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sub_bcast +func.func @test_sub_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Sub"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sub_example +func.func @test_sub_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Sub"(%arg0, %arg1) : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sub +func.func @test_sub(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Sub"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sub_uint8 +func.func @test_sub_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8>, !torch.int -> !torch.vtensor<[3,4,5],ui8> + %0 = torch.operator "onnx.Sub"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> + return %0 : !torch.vtensor<[3,4,5],ui8> +} + +// ----- + +// CHECK-LABEL: func.func @test_sum_example +func.func @test_sum_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[SUM:.*]] = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: %[[SUM_1:.*]] = torch.aten.add.Tensor %[[SUM]], %arg2, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: %[[SUM_2:.*]] = torch.aten.add.Tensor %[[SUM_1]], %arg3, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Sum"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sum_one_input +func.func @test_sum_one_input(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.Sum"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sum_two_inputs +func.func @test_sum_two_inputs(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Sum"(%arg0, %arg1) : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_where_example +func.func @test_where_example(%arg0: !torch.vtensor<[2,2],i1>, %arg1: !torch.vtensor<[2,2],f32>, %arg2: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.where.self %arg0, %arg1, %arg2 : !torch.vtensor<[2,2],i1>, !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32> -> !torch.vtensor<[2,2],f32> + %0 = torch.operator "onnx.Where"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,2],i1>, !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> + return %0 : !torch.vtensor<[2,2],f32> +} + +// CHECK-LABEL: func.func @test_where_long_example +func.func @test_where_long_example(%arg0: !torch.vtensor<[2,2],i1>, %arg1: !torch.vtensor<[2,2],si64>, %arg2: !torch.vtensor<[2,2],si64>) -> !torch.vtensor<[2,2],si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.where.self %arg0, %arg1, %arg2 : !torch.vtensor<[2,2],i1>, !torch.vtensor<[2,2],si64>, !torch.vtensor<[2,2],si64> -> !torch.vtensor<[2,2],si64> + %0 = torch.operator "onnx.Where"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,2],i1>, !torch.vtensor<[2,2],si64>, !torch.vtensor<[2,2],si64>) -> !torch.vtensor<[2,2],si64> + return %0 : !torch.vtensor<[2,2],si64> +} + +// ----- + +// CHECK-LABEL: func.func @test_xor2d +func.func @test_xor2d(%arg0: !torch.vtensor<[3,4],i1>, %arg1: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4],i1>, !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1> + %0 = torch.operator "onnx.Xor"(%arg0, %arg1) : (!torch.vtensor<[3,4],i1>, !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> + return %0 : !torch.vtensor<[3,4],i1> +} + +// ----- + +// CHECK-LABEL: func.func @test_xor3d +func.func @test_xor3d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[3,4,5],i1>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],i1> -> !torch.vtensor<[3,4,5],i1> + %0 = torch.operator "onnx.Xor"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],i1>) -> !torch.vtensor<[3,4,5],i1> + return %0 : !torch.vtensor<[3,4,5],i1> +} + +// ----- + +// CHECK-LABEL: func.func @test_xor4d +func.func @test_xor4d(%arg0: !torch.vtensor<[3,4,5,6],i1>, %arg1: !torch.vtensor<[3,4,5,6],i1>) -> !torch.vtensor<[3,4,5,6],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4,5,6],i1>, !torch.vtensor<[3,4,5,6],i1> -> !torch.vtensor<[3,4,5,6],i1> + %0 = torch.operator "onnx.Xor"(%arg0, %arg1) : (!torch.vtensor<[3,4,5,6],i1>, !torch.vtensor<[3,4,5,6],i1>) -> !torch.vtensor<[3,4,5,6],i1> + return %0 : !torch.vtensor<[3,4,5,6],i1> +} + +// ----- + +// CHECK-LABEL: func.func @test_xor_bcast3v1d +func.func @test_xor_bcast3v1d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[5],i1>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[5],i1> -> !torch.vtensor<[3,4,5],i1> + %0 = torch.operator "onnx.Xor"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],i1>, !torch.vtensor<[5],i1>) -> !torch.vtensor<[3,4,5],i1> + return %0 : !torch.vtensor<[3,4,5],i1> +} + +// ----- + +// CHECK-LABEL: func.func @test_xor_bcast4v4d +func.func @test_xor_bcast4v4d(%arg0: !torch.vtensor<[1,4,1,6],i1>, %arg1: !torch.vtensor<[3,1,5,6],i1>) -> !torch.vtensor<[3,4,5,6],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[1,4,1,6],i1>, !torch.vtensor<[3,1,5,6],i1> -> !torch.vtensor<[3,4,5,6],i1> + %0 = torch.operator "onnx.Xor"(%arg0, %arg1) : (!torch.vtensor<[1,4,1,6],i1>, !torch.vtensor<[3,1,5,6],i1>) -> !torch.vtensor<[3,4,5,6],i1> + return %0 : !torch.vtensor<[3,4,5,6],i1> +} + +// ----- + +// CHECK-LABEL: func.func @test_squeeze +func.func @test_squeeze(%arg0: !torch.vtensor<[1,3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT4:.*]] = torch.constant.int 4 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: torch.prims.squeeze %arg0, %6 : !torch.vtensor<[1,3,4,5],f32>, !torch.list -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Squeeze"(%arg0, %arg1) : (!torch.vtensor<[1,3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_squeeze_two_axes +func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT5:.*]] = torch.constant.int 5 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int5 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %9, %int5 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5, %11 : (!torch.int, !torch.int) -> !torch.list + // CHECK: torch.prims.squeeze %arg0, %12 : !torch.vtensor<[3,1,4,5,1],f32>, !torch.list -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Squeeze"(%arg0, %arg1) : (!torch.vtensor<[3,1,4,5,1],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_unsqueeze_axis_0 +func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT3:.*]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: torch.constant.bool false + // CHECK: torch.constant.none + // CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[1,3,4,5],f32> + %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32> + return %0 : !torch.vtensor<[1,3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_unsqueeze_axis_1 +func.func @test_unsqueeze_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT3:.*]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,1,4,5],f32> + %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,4,5],f32> + return %0 : !torch.vtensor<[3,1,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_unsqueeze_axis_2 +func.func @test_unsqueeze_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,1,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT3:.*]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,1,5],f32> + %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,1,5],f32> + return %0 : !torch.vtensor<[3,4,1,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_unsqueeze_negative_axes +func.func @test_unsqueeze_negative_axes(%arg0: !torch.vtensor<[1,3,1,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,1,1,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT4:.*]] = torch.constant.int 4 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[1,3,1,5],f32>, !torch.int -> !torch.vtensor<[1,3,1,1,5],f32> + %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[1,3,1,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,1,1,5],f32> + return %0 : !torch.vtensor<[1,3,1,1,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_unsqueeze_three_axes +func.func @test_unsqueeze_three_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT3:.*]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %13, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %15, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.tensor %18, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[3],si64> + // CHECK: torch.aten.sort %19, %int0, %false : !torch.vtensor<[3],si64>, !torch.int, !torch.bool -> !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64> + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %20 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.unsqueeze %arg0, %21 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor + // CHECK: %[[INT1_2:.*]] = torch.constant.int 1 + // CHECK: torch.aten.select.int %values, %int0, %int1_2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %23 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.unsqueeze %22, %24 : !torch.vtensor, !torch.int -> !torch.vtensor + // CHECK: %[[INT2_3:.*]] = torch.constant.int 2 + // CHECK: torch.aten.select.int %values, %int0, %int2_3 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %26 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.unsqueeze %25, %27 : !torch.vtensor, !torch.int -> !torch.vtensor<[3,4,1,5,1,1],f32> + %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> + return %0 : !torch.vtensor<[3,4,1,5,1,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_unsqueeze_unsorted_axes +func.func @test_unsqueeze_unsorted_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT3:.*]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %13, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %15, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.tensor %18, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[3],si64> + // CHECK: torch.aten.sort %19, %int0, %false : !torch.vtensor<[3],si64>, !torch.int, !torch.bool -> !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64> + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %20 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.unsqueeze %arg0, %21 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor + // CHECK: %[[INT1_2:.*]] = torch.constant.int 1 + // CHECK: torch.aten.select.int %values, %int0, %int1_2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %23 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.unsqueeze %22, %24 : !torch.vtensor, !torch.int -> !torch.vtensor + // CHECK: %[[INT2_3:.*]] = torch.constant.int 2 + // CHECK: torch.aten.select.int %values, %int0, %int2_3 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %26 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.unsqueeze %25, %27 : !torch.vtensor, !torch.int -> !torch.vtensor<[3,4,1,5,1,1],f32> + %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> + return %0 : !torch.vtensor<[3,4,1,5,1,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_softmax_axis_0 +func.func @test_softmax_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.softmax.int %arg0, %int0, %none : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.none -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Softmax"(%arg0) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_softmax_axis_1 +func.func @test_softmax_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.softmax.int %arg0, %int1, %none : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.none -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Softmax"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_softmax_axis_2 +func.func @test_softmax_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.softmax.int %arg0, %int2, %none : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.none -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Softmax"(%arg0) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_softmax_default_axis +func.func @test_softmax_default_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.softmax.int %arg0, %int2, %none : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.none -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Softmax"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_softmax_large_number +func.func @test_softmax_large_number(%arg0: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.softmax.int %arg0, %int1, %none : !torch.vtensor<[2,4],f32>, !torch.int, !torch.none -> !torch.vtensor<[2,4],f32> + %0 = torch.operator "onnx.Softmax"(%arg0) : (!torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> + return %0 : !torch.vtensor<[2,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_softmax_negative_axis +func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.softmax.int %arg0, %int2, %none : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.none -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Softmax"(%arg0) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_selu +func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} { + // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1 + // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2 + // CHECK-DAG: %[[F3:.+]] = torch.constant.float 3 + // CHECK: %[[ELU:.+]] = torch.aten.elu %arg0, %[[F2]], %[[F3]], %[[F1]] + %0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_max_empty_set_fp +func.func @test_reduce_max_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[INF:.+]] = torch.constant.float 0x7FF0000000000000 + // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 + // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4 + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT1]], %[[INT4]] + // CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[INF]], %[[NONE]], %[[NONE]], %[[NONE]] + // CHECK: return %[[FULL]] + %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> + return %0 : !torch.vtensor<[2,1,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_max_empty_set_int +func.func @test_reduce_max_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[INF:.+]] = torch.constant.int 2147483647 + // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 + // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4 + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT1]], %[[INT4]] + // CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[INF]], %[[NONE]], %[[NONE]], %[[NONE]] + // CHECK: return %[[FULL]] + %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32> + return %0 : !torch.vtensor<[2,1,4],si32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_max_bool_inputs +func.func @test_reduce_max_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[IDX:.+]] = torch.constant.int 0 + // CHECK: %[[SZ:.+]] = torch.constant.int 0 + // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LST]], %[[TRUE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4,1],i1> + // CHECK: return %[[AMAX]] : !torch.vtensor<[4,1],i1> + %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> + return %0 : !torch.vtensor<[4,1],i1> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_max_bool_inputs_nokeepdims +func.func @test_reduce_max_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[IDX:.+]] = torch.constant.int 0 + // CHECK: %[[SZ:.+]] = torch.constant.int 0 + // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> + // CHECK: return %[[AMAX]] : !torch.vtensor<[4],i1> + %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_max_all_dims_default +func.func @test_reduce_max_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[I0:.+]] = torch.constant.int 0 + // CHECK: %[[I1:.+]] = torch.constant.int 1 + // CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I0]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[A0:.+]] = torch.aten.add.int %[[I0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I1]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[A1:.+]] = torch.aten.add.int %[[I1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[A0]], %[[A1]] + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[MAX:.+]] = torch.aten.amax %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[],i1> + // CHECK: return %[[MAX]] : !torch.vtensor<[],i1> + %0 = torch.operator "onnx.ReduceMax"(%arg0) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} + +// ----- + +func.func @test_reduce_max_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> + // CHECK: return %[[AMAX]] + %0 = torch.operator "onnx.ReduceMax"(%arg0) {torch.onnx.keepdims = 0 : si64, torch.onnx.axes=[1 : si64]} : (!torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example +func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: torch.aten.Bool.int %int1 : !torch.int -> !torch.bool + // CHECK: torch.aten.sum.dim_IntList %arg0, %[[NONE]], %0, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> + %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> + return %0 : !torch.vtensor<[1,1,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_sum_do_not_keepdims_example +func.func @test_reduce_sum_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: torch.aten.sum.dim_IntList %arg0, %6, %false, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> + %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> + return %0 : !torch.vtensor<[3,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_sum_empty_axes_input_noop_example +func.func @test_reduce_sum_empty_axes_input_noop_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[3,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64, torch.onnx.noop_with_empty_axes = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[3,2,2],f32> + return %0 : !torch.vtensor<[3,2,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_sum_empty_set_non_reduced_axis_zero +func.func @test_reduce_sum_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: torch.aten.sum.dim_IntList %arg0, %6, %true, %[[NONE]] : !torch.vtensor<[2,0,4],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[2,0,1],f32> + %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32> + return %0 : !torch.vtensor<[2,0,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_sum_keepdims_example +func.func @test_reduce_sum_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: torch.aten.sum.dim_IntList %arg0, %6, %true, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> + %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> + return %0 : !torch.vtensor<[3,1,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_sum_negative_axes_keepdims_example +func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: torch.aten.sum.dim_IntList %arg0, %6, %true, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> + %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> + return %0 : !torch.vtensor<[3,1,2],f32> +} + +// ----- + +// CHECK-LABEL: @test_reduce_mean_negative_axes_keepdims_example +func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + // CHECK: %[[DIM:.+]] = torch.constant.int 0 + // CHECK: %[[A0:.+]] = torch.constant.int 0 + // CHECK: %[[SEL0:.+]] = torch.aten.select.int %[[TENSOR]], %[[DIM]], %[[A0]] + // CHECK: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] + // CHECK: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK: %[[RANK:.+]] = torch.constant.int 3 + // CHECK: %[[LT0:.+]] = torch.aten.lt.int %[[ITEM0]], %[[ZERO]] + // CHECK: %[[BOOL0:.+]] = torch.aten.Int.bool %[[LT0]] + // CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[BOOL0]], %[[RANK]] + // CHECK: %[[ADD0:.+]] = torch.aten.add.int %[[ITEM0]], %[[MUL0]] + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD0]] + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.mean.dim %arg0, %[[LIST]], %[[TRUE]], %[[NONE]] + // CHECK: return %[[SUM]] + %cst = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %0 = torch.operator "onnx.ReduceMean"(%arg0, %cst) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> + return %0 : !torch.vtensor<[3,1,2],f32> +} + +// ----- + +// CHECK-LABEL: @test_reduce_mean_one_axes_dropdims_example +func.func @test_reduce_mean_one_axes_dropdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { + // CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + // CHECK: %[[DIM:.+]] = torch.constant.int 0 + // CHECK: %[[A0:.+]] = torch.constant.int 0 + // CHECK: %[[SEL0:.+]] = torch.aten.select.int %[[TENSOR]], %[[DIM]], %[[A0]] + // CHECK: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] + // CHECK: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK: %[[RANK:.+]] = torch.constant.int 3 + // CHECK: %[[LT0:.+]] = torch.aten.lt.int %[[ITEM0]], %[[ZERO]] + // CHECK: %[[BOOL0:.+]] = torch.aten.Int.bool %[[LT0]] + // CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[BOOL0]], %[[RANK]] + // CHECK: %[[ADD0:.+]] = torch.aten.add.int %[[ITEM0]], %[[MUL0]] + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD0]] + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.mean.dim %arg0, %[[LIST]], %[[FALSE]], %[[NONE]] + // CHECK: return %[[SUM]] + %cst = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %0 = torch.operator "onnx.ReduceMean"(%arg0, %cst) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> + return %0 : !torch.vtensor<[3,2],f32> +} +// ----- + +// CHECK-LABEL: @test_reduce_mean_one_axesattr_dropdims_example +func.func @test_reduce_mean_one_axesattr_dropdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[INT3]] + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[MEAN:.+]] = torch.aten.mean.dim %arg0, %[[LIST]], %[[FALSE]], %[[NONE]] + // CHECK: return %[[MEAN]] + %0 = torch.operator "onnx.ReduceMean"(%arg0) {torch.onnx.keepdims = 0 : si64, torch.onnx.axes = [1 : si64]} : (!torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,2],f32> + return %0 : !torch.vtensor<[3,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_min_empty_set_fp +func.func @test_reduce_min_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[INF:.+]] = torch.constant.float 0x7FF0000000000000 + // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 + // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4 + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT1]], %[[INT4]] + // CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[INF]], %[[NONE]], %[[NONE]], %[[NONE]] + // CHECK: return %[[FULL]] + %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> + return %0 : !torch.vtensor<[2,1,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_min_empty_set_int +func.func @test_reduce_min_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[INF:.+]] = torch.constant.int 2147483647 + // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 + // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4 + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT1]], %[[INT4]] + // CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[INF]], %[[NONE]], %[[NONE]], %[[NONE]] + // CHECK: return %[[FULL]] + %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32> + return %0 : !torch.vtensor<[2,1,4],si32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_min_bool_inputs +func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[IDX:.+]] = torch.constant.int 0 + // CHECK: %[[SZ:.+]] = torch.constant.int 0 + // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LST]], %[[TRUE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4,1],i1> + // CHECK: return %[[AMIN]] : !torch.vtensor<[4,1],i1> + %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> + return %0 : !torch.vtensor<[4,1],i1> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_min_bool_inputs_nokeepdims +func.func @test_reduce_min_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[IDX:.+]] = torch.constant.int 0 + // CHECK: %[[SZ:.+]] = torch.constant.int 0 + // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> + // CHECK: return %[[AMIN]] : !torch.vtensor<[4],i1> + %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_min_all_dims_default +func.func @test_reduce_min_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[I0:.+]] = torch.constant.int 0 + // CHECK: %[[I1:.+]] = torch.constant.int 1 + // CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I0]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[A0:.+]] = torch.aten.add.int %[[I0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I1]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[A1:.+]] = torch.aten.add.int %[[I1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[A0]], %[[A1]] + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[MIN:.+]] = torch.aten.amin %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[],i1> + // CHECK: return %[[MIN]] : !torch.vtensor<[],i1> + %0 = torch.operator "onnx.ReduceMin"(%arg0) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} + +// ----- + +func.func @test_reduce_min_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> + // CHECK: return %[[AMIN]] + %0 = torch.operator "onnx.ReduceMin"(%arg0) {torch.onnx.keepdims = 0 : si64, torch.onnx.axes=[1 : si64]} : (!torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_prod_default_axes_keepdims_random +func.func @test_reduce_prod_default_axes_keepdims_random(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[RANK:.*]] = torch.aten.dim %arg0 : !torch.vtensor<[3,2,2],f32> -> !torch.int + // CHECK: %[[LT:.*]] = torch.aten.lt.int %[[INT0_0]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.*]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[INT0_0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LT_0:.*]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL_0:.*]] = torch.aten.Int.bool %[[LT_0]] : !torch.bool -> !torch.int + // CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[BOOL_0]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[INT1]], %[[MUL_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LT_1:.*]] = torch.aten.lt.int %[[INT2]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL_1:.*]] = torch.aten.Int.bool %[[LT_1]] : !torch.bool -> !torch.int + // CHECK: %[[MUL_1:.*]] = torch.aten.mul.int %[[BOOL_1]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD_1:.*]] = torch.aten.add.int %[[INT2]], %[[MUL_1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[TRUE:.*]] = torch.constant.bool true + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[PROD_0:.*]] = torch.aten.prod.dim_int %arg0, %[[ADD]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32> + // CHECK: %[[PROD_1:.*]] = torch.aten.prod.dim_int %[[PROD_0]], %[[ADD_0]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32> + // CHECK: %[[PROD_2:.*]] = torch.aten.prod.dim_int %[[PROD_1]], %[[ADD_1]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> + // CHECK: return %[[PROD_2]] : !torch.vtensor<[1,1,1],f32> + %0 = torch.operator "onnx.ReduceProd"(%arg0) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> + return %0 : !torch.vtensor<[1,1,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_prod_keepdims_random +func.func @test_reduce_prod_keepdims_random(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[INT0_0:.*]] = torch.constant.int 0 +// CHECK: %[[SELECT:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> +// CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int +// CHECK: %[[DIM:.*]] = torch.aten.dim %arg0 : !torch.vtensor<[3,2,2],f32> -> !torch.int +// CHECK: %[[LT:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool +// CHECK: %[[BOOL:.*]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int +// CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int +// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int +// CHECK: %[[BOOL:.*]] = torch.constant.bool true +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[PROD:.*]] = torch.aten.prod.dim_int %arg0, %[[ADD]], %[[BOOL]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> +// CHECK: return %[[PROD]] : !torch.vtensor<[3,1,2],f32> + %0 = torch.operator "onnx.ReduceProd"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> + return %0 : !torch.vtensor<[3,1,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sinh +func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} { + // CHECK: torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Sinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_split_variable_parts_2d_opset18( +// CHECK-SAME: %[[VAL_INPUT:.*]]: !torch.vtensor<[2,6],f32>, +// CHECK-SAME: %[[VAL_SPLIT:.*]]: !torch.vtensor<[2],si64> +// CHECK: %[[VAL_SPLIT_LIST:.*]] = torch.prim.tolist(%[[VAL_SPLIT]]) : !torch.vtensor<[2],si64> -> !torch.list +// CHECK: %[[VAL_AXIS:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_RESULT_LIST:.*]] = torch.aten.split_with_sizes %[[VAL_INPUT]], %[[VAL_SPLIT_LIST]], %[[VAL_AXIS]] : !torch.vtensor<[2,6],f32>, !torch.list, !torch.int -> !torch.list> +// CHECK: %[[VAL_VARIADIC_RETURN_VALUE:.*]]:2 = torch.prim.ListUnpack %[[VAL_RESULT_LIST]] : !torch.list> -> !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,4],f32> +// CHECK: return %[[VAL_VARIADIC_RETURN_VALUE]]#0, %[[VAL_VARIADIC_RETURN_VALUE]]#1 : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,4],f32> +func.func @test_split_variable_parts_2d_opset18(%arg0: !torch.vtensor<[2,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,4],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0:2 = torch.operator "onnx.Split"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,6],f32>, !torch.vtensor<[2],si64>) -> (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,4],f32>) + return %0#0, %0#1 : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_split_2d_uneven_split_opset18( +// CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK: %[[AXIS:.*]] = torch.constant.int 1 +// CHECK: %[[SPLIT_SIZE:.*]] = torch.constant.int 3 +// CHECK: %[[SPLIT_RESULT:.*]] = torch.aten.split.Tensor %[[INPUT_TENSOR]], %[[SPLIT_SIZE]], %[[AXIS]] : !torch.vtensor<[2,8],f32>, !torch.int, !torch.int -> !torch.list> +// CHECK: %[[UNPACKED_TENSORS:.*]]:3 = torch.prim.ListUnpack %[[SPLIT_RESULT]] : !torch.list> -> !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32> +// CHECK: return %[[UNPACKED_TENSORS]]#0, %[[UNPACKED_TENSORS]]#1, %[[UNPACKED_TENSORS]]#2 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32> +// CHECK: } +func.func @test_split_2d_uneven_split_opset18(%arg0: !torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0:3 = torch.operator "onnx.Split"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.num_outputs = 3 : si64} : (!torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>) + return %0#0, %0#1, %0#2 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_tan +func.func @test_tan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TAN:.+]] = torch.aten.tan %arg0 + %0 = torch.operator "onnx.Tan"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_transpose_default +func.func @test_transpose_default(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[I2:.+]] = torch.constant.int 2 + // CHECK: %[[TRANSPOSE:.+]] = torch.aten.transpose.int %arg0, %[[I0]], %[[I2]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,3,2],f32> + %0 = torch.operator "onnx.Transpose"(%arg0) : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32> + + // CHECK: return %[[TRANSPOSE]] + return %0 : !torch.vtensor<[4,3,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_transpose_all_permutations_4 +func.func @test_transpose_all_permutations_4(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[I2:.+]] = torch.constant.int 2 + // CHECK: %[[TRANSPOSE0:.+]] = torch.aten.transpose.int %arg0, %[[I0]], %[[I2]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,3,2],f32> + // CHECK-DAG: %[[I1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[I2:.+]] = torch.constant.int 2 + // CHECK: %[[TRANSPOSE1:.+]] = torch.aten.transpose.int %[[TRANSPOSE0]], %[[I1]], %[[I2]] : !torch.vtensor<[4,3,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,2,3],f32> + %0 = torch.operator "onnx.Transpose"(%arg0) {torch.onnx.perm = [2 : si64, 0 : si64, 1 : si64]} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,2,3],f32> + + // CHECK: return %[[TRANSPOSE1]] + return %0 : !torch.vtensor<[4,2,3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_transpose_dynamic +func.func @test_transpose_dynamic(%arg0: !torch.vtensor<[?,32,5,128],f32>) -> !torch.vtensor<[?,5,32,128],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK-DAG: %[[I1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[I2:.+]] = torch.constant.int 2 + // CHECK: %[[TRANSPOSE:.+]] = torch.aten.transpose.int %arg0, %[[I1]], %[[I2]] : !torch.vtensor<[?,32,5,128],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,5,32,128],f32> + %0 = torch.operator "onnx.Transpose"(%arg0) {torch.onnx.perm = [0 : si64, 2 : si64, 1 : si64, 3 : si64]} : (!torch.vtensor<[?,32,5,128],f32>) -> !torch.vtensor<[?,5,32,128],f32> + return %0 : !torch.vtensor<[?,5,32,128],f32> +} + + +// ----- + +// CHECK-LABEL: func.func @test_slice +func.func @test_slice(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[2],si64>, %arg2: !torch.vtensor<[2],si64>, %arg3: !torch.vtensor<[2],si64>, %arg4: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,10,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + //CHECK: %[[INDEX_TO_GRAB:.*]] = torch.constant.int 0 + + //CHECK: %[[CONST_0:.*]] = torch.constant.int 0 + //CHECK: %[[ZERO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_0:.*]] : !torch.int -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_ELEMENT_0:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[ENDS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[ENDS_ELEMENT_0:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[AXES_INDEX_VEC_0:.*]] = torch.aten.index_select %arg3, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[AXES_ELEMENT_0:.*]] = torch.aten.item %[[AXES_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[STEPS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg4, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STEPS_ELEMENT_0:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %arg0, %[[AXES_ELEMENT_0:.*]], %[[STARTS_ELEMENT_0:.*]], %[[ENDS_ELEMENT_0:.*]], %[[STEPS_ELEMENT_0:.*]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,10,5],f32> + + //CHECK: %[[CONST_1:.*]] = torch.constant.int 1 + //CHECK: %[[ONE_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_1:.*]] : !torch.int -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_ELEMENT_1:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[ENDS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[ENDS_ELEMENT_1:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[AXES_INDEX_VEC_1:.*]] = torch.aten.index_select %arg3, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[AXES_ELEMENT_2:.*]] = torch.aten.item %[[AXES_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[STEPS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg4, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STEPS_ELEMENT_1:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: torch.aten.slice.Tensor %[[SLICE_0:.*]], %[[AXES_ELEMENT_1:.*]], %[[STARTS_ELEMENT_1:.*]], %[[ENDS_ELEMENT_1:.*]], %[[STEPS_ELEMENT_1:.*]] : !torch.vtensor<[?,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,10,5],f32> + %0 = torch.operator "onnx.Slice"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[20,10,5],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,10,5],f32> + return %0 : !torch.vtensor<[3,10,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_slice_default_axes_and_slices +func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + //CHECK: %[[NONE_1:.*]] = torch.constant.none + //CHECK: %[[AXES_DEFAULT_SIZE:.*]] = torch.constant.int 3 + //CHECK: %[[DEFAULT_SIZE_INPUT:.*]] = torch.prim.ListConstruct %[[DEFAULT_SIZE_AMOUNT:.*]] : (!torch.int) -> !torch.list + //CHECK: %[[DEFAULT_SIZES:.*]] = torch.aten.ones %[[DEFAULT_SIZE_INPUT:.*]], %[[NONE_2:.*]], %[[NONE_2:.*]], %[[NONE_2:.*]], %[[NONE_2:.*]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64> + //CHECK: %[[INDEX_TO_GRAB:.*]] = torch.constant.int 0 + + //CHECK: %[[CONST_0:.*]] = torch.constant.int 0 + //CHECK: %[[ZERO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_0:.*]] : !torch.int -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_ELEMENT_0:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[ENDS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[ENDS_ELEMENT_0:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[STEPS_INDEX_VEC_0:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STEPS_ELEMENT_0:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %arg0, %[[CONST_0:.*]], %[[STARTS_ELEMENT_0:.*]], %[[ENDS_ELEMENT_0:.*]], %[[STEPS_ELEMENT_0:.*]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32> + + //CHECK: %[[CONST_1:.*]] = torch.constant.int 1 + //CHECK: %[[ONE_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_1:.*]] : !torch.int -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_ELEMENT_1:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[ENDS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[ENDS_ELEMENT_1:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[STEPS_INDEX_VEC_1:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STEPS_ELEMENT_1:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[TWO_INDEX_VEC:.*]] = torch.aten.slice.Tensor %[[SLICE_0:.*]], %[[CONST_1:.*]], %[[STARTS_ELEMENT_1:.*]], %[[ENDS_ELEMENT_1:.*]], %[[STEPS_ELEMENT_1:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32> + + //CHECK: %[[CONST_2:.*]] = torch.constant.int 2 + //CHECK: %[[TWO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_2:.*]] : !torch.int -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_INDEX_VEC_2:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_ELEMENT_2:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[ENDS_INDEX_VEC_2:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[ENDS_ELEMENT_2:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[STEPS_INDEX_VEC_2:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STEPS_ELEMENT_2:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: torch.aten.slice.Tensor %[[TWO_INDEX_VEC:.*]], %[[CONST_2:.*]], %[[STARTS_ELEMENT_2:.*]], %[[ENDS_ELEMENT_2:.*]], %[[STEPS_ELEMENT_2:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32> + %0 = torch.operator "onnx.Slice"(%arg0, %arg1, %arg2) : (!torch.vtensor<[20,10,5],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32> + return %0 : !torch.vtensor<[20,10,1],f32> +} + +// ----- + +// CHECK-LABEL: @test_slice_default_axes_and_steps +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[20,10,5],f32>, +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[1],si64>, +// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[1],si64> + +// CHECK: %[[ZERO0:.*]] = torch.constant.int 0 +// CHECK-NEXT: %[[ZERO1:.*]] = torch.constant.int 0 +// CHECK-NEXT: %[[SCALAR:.*]] = torch.prim.NumToTensor.Scalar %[[ZERO1]] : !torch.int -> !torch.vtensor<[1],si64> +// CHECK-NEXT: %[[SELECT0:.*]] = torch.aten.index_select %[[ARG1]], %[[ZERO0]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> +// CHECK-NEXT: %[[ITEM0:.*]] = torch.aten.item %[[SELECT0]] : !torch.vtensor<[1],si64> -> !torch.int +// CHECK-NEXT: %[[SELECT1:.*]] = torch.aten.index_select %[[ARG2]], %[[ZERO0]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> +// CHECK-NEXT: %[[ITEM1:.*]] = torch.aten.item %[[SELECT1]] : !torch.vtensor<[1],si64> -> !torch.int +// CHECK-NEXT: %[[SELECT3:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO0]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> +// CHECK-NEXT: %[[ITEM3:.*]] = torch.aten.item %[[SELECT3]] : !torch.vtensor<[1],si64> -> !torch.int +// CHECK: torch.aten.slice.Tensor %[[ARG0]], %[[ZERO1]], %[[ITEM0]], %[[ITEM1]], %[[ITEM3]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32> + +func.func @test_slice_default_axes_and_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[1],si64>, %arg2: !torch.vtensor<[1],si64>) -> !torch.vtensor<[20,10,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + %0 = torch.operator "onnx.Slice"(%arg0, %arg1, %arg2) : (!torch.vtensor<[20,10,5],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[20,10,1],f32> + return %0 : !torch.vtensor<[20,10,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_slice_default_steps +func.func @test_slice_default_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[3],si64>, %arg3: !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + //CHECK: %[[NONE:.*]] = torch.constant.none + //CHECK: %[[DEFAULT_SIZE_AMOUNT:.*]] = torch.constant.int 3 + //CHECK: %[[DEFAULT_SIZE_INPUT:.*]] = torch.prim.ListConstruct %[[DEFAULT_SIZE_AMOUNT:.*]] : (!torch.int) -> !torch.list + //CHECK: %[[DEFAULT_SIZES:.*]] = torch.aten.ones %[[DEFAULT_SIZE_INPUT:.*]], %[[NONE:.*]], %[[NONE:.*]], %[[NONE:.*]], %[[NONE:.*]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64> + //CHECK: %[[INDEX_TO_GRAB:.*]] = torch.constant.int 0 + + //CHECK: %[[CONST_0:.*]] = torch.constant.int 0 + //CHECK: %[[ZERO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_0:.*]] : !torch.int -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_ELEMENT_0:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[ENDS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[ENDS_ELEMENT_0:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[AXES_INDEX_VEC_0:.*]] = torch.aten.index_select %arg3, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[AXES_ELEMENT_0:.*]] = torch.aten.item %[[AXES_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[STEPS_INDEX_VEC_0:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STEPS_ELEMENT_0:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %arg0, %[[AXES_ELEMENT_0:.*]], %[[STARTS_ELEMENT_0:.*]], %[[ENDS_ELEMENT_0:.*]], %[[STEPS_ELEMENT_0:.*]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32> + + //CHECK: %[[CONST_1:.*]] = torch.constant.int 1 + //CHECK: %[[ONE_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_1:.*]] : !torch.int -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_ELEMENT_1:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[ENDS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[ENDS_ELEMENT_1:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[AXES_INDEX_VEC_1:.*]] = torch.aten.index_select %arg3, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[AXES_ELEMENT_1:.*]] = torch.aten.item %[[AXES_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[STEPS_INDEX_VEC_1:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STEPS_ELEMENT_1:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[TWO_INDEX_VEC:.*]] = torch.aten.slice.Tensor %[[SLICE_0:.*]], %[[AXES_ELEMENT_1:.*]], %[[STARTS_ELEMENT_1:.*]], %[[ENDS_ELEMENT_1:.*]], %[[STEPS_ELEMENT_1:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32> + + //CHECK: %[[CONST_1:.*]] = torch.constant.int 2 + //CHECK: %[[TWO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_1:.*]] : !torch.int -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_INDEX_VEC_2:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_ELEMENT_2:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[ENDS_INDEX_VEC_2:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[ENDS_ELEMENT_2:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[AXES_INDEX_VEC_2:.*]] = torch.aten.index_select %arg3, %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[AXES_ELEMENT_2:.*]] = torch.aten.item %[[AXES_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[STEPS_INDEX_VEC_2:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STEPS_ELEMENT_2:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: torch.aten.slice.Tensor %[[TWO_INDEX_VEC:.*]], %[[AXES_ELEMENT_2:.*]], %[[STARTS_ELEMENT_2:.*]], %[[ENDS_ELEMENT_2:.*]], %[[STEPS_ELEMENT_2:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32> + %0 = torch.operator "onnx.Slice"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[20,10,5],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32> + return %0 : !torch.vtensor<[20,10,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reshape_negative_dim +func.func @test_reshape_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[INT6:.+]] = torch.constant.int 6 + // CHECK: %[[INT2_0:.+]] = torch.constant.int 2 + // CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT6]], %[[INT2_0]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,6,2],f32> + %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32> + return %0 : !torch.vtensor<[2,6,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reshape_negative_extended_dims +func.func @test_reshape_negative_extended_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,2,3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[INT4:.+]] = torch.constant.int 4 + // CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT1]], %[[INT2]], %[[INT3]], %[[INT4]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[1,2,3,4],f32> + %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,2,3,4],f32> + return %0 : !torch.vtensor<[1,2,3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reshape_one_dim +func.func @test_reshape_one_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[24],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT24:.+]] = torch.constant.int 24 + // CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT24]] : (!torch.int) -> !torch.list + // CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[24],f32> + %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[24],f32> + return %0 : !torch.vtensor<[24],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reshape_reduced_dims +func.func @test_reshape_reduced_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,12],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[INT12:.+]] = torch.constant.int 12 + // CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT12]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,12],f32> + %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,12],f32> + return %0 : !torch.vtensor<[2,12],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reshape_reordered_all_dims +func.func @test_reshape_reordered_all_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[4,2,3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT4:.+]] = torch.constant.int 4 + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT4]], %[[INT2]], %[[INT3]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[4,2,3],f32> + %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[4,2,3],f32> + return %0 : !torch.vtensor<[4,2,3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reshape_zero_and_negative_dim +func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[2,3,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[INT4:.+]] = torch.constant.int 4 + // CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]], %[[INT1]], %[[INT4]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,3,1,4],f32> + %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[2,3,1,4],f32> + return %0 : !torch.vtensor<[2,3,1,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_range_float64_type + func.func @test_range_float64_type(%arg0: !torch.vtensor<[],f64>, %arg1: !torch.vtensor<[],f64>, %arg2: !torch.vtensor<[],f64>) -> !torch.vtensor<[2],f64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] torch.constant.none + // CHECK: torch.aten.item %arg0 : !torch.vtensor<[],f64> -> !torch.float + // CHECK: torch.aten.item %arg1 : !torch.vtensor<[],f64> -> !torch.float + // CHECK: torch.aten.item %arg2 : !torch.vtensor<[],f64> -> !torch.float + // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.float, !torch.float, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],f64> + %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],f64>, !torch.vtensor<[],f64>, !torch.vtensor<[],f64>) -> !torch.vtensor<[2],f64> + return %0 : !torch.vtensor<[2],f64> + } + +// ----- + +// CHECK-LABEL: func.func @test_range_float32_type + func.func @test_range_float32_type(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] torch.constant.none + // CHECK: torch.aten.item %arg0 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.float, !torch.float, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],f32> + %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[2],f32> + return %0 : !torch.vtensor<[2],f32> + } + +// ----- + +// CHECK-LABEL: func.func @test_range_int64_type + func.func @test_range_int64_type(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] torch.constant.none + // CHECK: torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: torch.aten.item %arg2 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64> + %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2],si64> + return %0 : !torch.vtensor<[2],si64> + } + +// ----- + +// CHECK-LABEL: func.func @test_range_int32_type + func.func @test_range_int32_type(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>) -> !torch.vtensor<[2],si32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] torch.constant.none + // CHECK: torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int + // CHECK: torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int + // CHECK: torch.aten.item %arg2 : !torch.vtensor<[],si32> -> !torch.int + // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si32> + %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si32>, !torch.vtensor<[],si32>, !torch.vtensor<[],si32>) -> !torch.vtensor<[2],si32> + return %0 : !torch.vtensor<[2],si32> + } + +// ----- + + // CHECK-LABEL: func.func @test_range_int16_type + func.func @test_range_int16_type(%arg0: !torch.vtensor<[],si16>, %arg1: !torch.vtensor<[],si16>, %arg2: !torch.vtensor<[],si16>) -> !torch.vtensor<[2],si16> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] torch.constant.none + // CHECK: torch.aten.item %arg0 : !torch.vtensor<[],si16> -> !torch.int + // CHECK: torch.aten.item %arg1 : !torch.vtensor<[],si16> -> !torch.int + // CHECK: torch.aten.item %arg2 : !torch.vtensor<[],si16> -> !torch.int + // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si16> + %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si16>, !torch.vtensor<[],si16>, !torch.vtensor<[],si16>) -> !torch.vtensor<[2],si16> + return %0 : !torch.vtensor<[2],si16> + } + +// ----- + +// CHECK-LABEL : func.func @test_top_k + func.func @test_top_k(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64} { + // CHECK: %[[RESULTS:.*]]:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) + // CHECK: return %[[RESULTS]]#0, %[[RESULTS]]#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> + %0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) + return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> + } + +// ----- + +// CHECK-LABEL: func.func @test_top_k_smallest + func.func @test_top_k_smallest(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[RESULTS:.*]]:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64, torch.onnx.largest = 0 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) + // CHECK: return %[[RESULTS]]#0, %[[RESULTS]]#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> + %0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64, torch.onnx.largest = 0 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) + return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> + } + +// ----- + +// CHECK-LABEL: func.func @test_top_k_negative_axis + func.func @test_top_k_negative_axis(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64} { + // CHECK: %[[RESULTS:.*]]:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) + // CHECK: return %[[RESULTS]]#0, %[[RESULTS]]#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> + %0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) + return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> + } + +// ----- + +// CHECK-LABEL: func.func @test_tile +func.func @test_tile(%arg0: !torch.vtensor<[2, 3, 4],f32>, %arg1: !torch.vtensor<[3], si64>) -> !torch.vtensor<[2,12,4],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]], %[[ELE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %7 = torch.aten.tile %arg0, %[[DIM_LIST]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,12,4],f32> + %0 = torch.operator "onnx.Tile"(%arg0, %arg1) : (!torch.vtensor<[2, 3, 4],f32>, !torch.vtensor<[3], si64>) -> !torch.vtensor<[2, 12, 4],f32> + return %0 : !torch.vtensor<[2, 12, 4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sign +func.func @test_sign(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.sign %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Sign"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_size +func.func @test_size(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 9 : si64} { + // CHECK-DAG %[[INT0:.+]] = torch.constant.int 0 + // CHECK-DAG %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG %[[INT2:.+]] = torch.constant.int 2 + // CHECK-DAG %[[D0:.+]] = torch.aten.size.int %arg0, %[[INT0]] + // CHECK-DAG %[[D1:.+]] = torch.aten.size.int %arg0, %[[INT1]] + // CHECK-DAG %[[D2:.+]] = torch.aten.size.int %arg0, %[[INT2]] + // CHECK-DAG %[[FALSE:.+]] = torch.constant.bool false + // CHECK-DAG %[[NONE:.+]] = torch.constant.none + // CHECK-DAG %[[MUL0:.+]] = torch.aten.mul.int %[[D0]], %[[D1]] + // CHECK-DAG %[[MUL1:.+]] = torch.aten.mul.int %[[MUL0]], %[[D3]] + // CHECK-DAG %[[TENSOR:.+]] = torch.aten.tensor.int %[[MUL1]], %[[NONE]], %[[NONE]], %[[FALSE]] + // CHECK return %[[TENSOR]] + %0 = torch.operator "onnx.Size"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[],si32> + return %0 : !torch.vtensor<[],si32> +} + diff --git a/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir b/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir new file mode 100644 index 000000000000..0a8bbfe1a8e3 --- /dev/null +++ b/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir @@ -0,0 +1,48 @@ +// RUN: torch-mlir-opt <%s -split-input-file -verify-diagnostics -convert-torch-onnx-to-torch +// FB OPT OPS from https://github.com/llvm/torch-mlir/issues/2689 + +// ----- +// Fixed unecessarily high since-opset value +func.func @cast_operation(%arg0: !torch.vtensor<[?,?,?,?],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %208 = torch.operator "onnx.Cast"(%arg0) { + torch.onnx.to = 1 : si64 + } : (!torch.vtensor<[?,?,?,?],si64>) -> !torch.vtensor<[?,?,?,?],f32> + return %208 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- +func.func @div_operation(%arg0: !torch.vtensor<[1,64,768],f32>, + %arg1: !torch.vtensor<[1,64,1],f32>) + -> !torch.vtensor<[1,64,768],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %209 = torch.operator "onnx.Div"(%arg0, %arg1) : (!torch.vtensor<[1,64,768],f32>, !torch.vtensor<[1,64,1],f32>) -> !torch.vtensor<[1,64,768],f32> + return %209 : !torch.vtensor<[1,64,768],f32> +} + +// ----- +// Fixed. +// this is the onnx opset 1 version of Equal, only int types. +// this used to fail to legalize because the "since" value is set unecessarily high (19) +func.func @equal_operation(%arg0: !torch.vtensor<[4],si64>, + %arg1: !torch.vtensor<[4],si64>) + -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %205 = torch.operator "onnx.Equal"(%arg0, %arg1) : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],i1> + return %205 : !torch.vtensor<[4],i1> +} + + +// ----- +func.func @reduce_mean_operation(%arg0: !torch.vtensor<[1,64,768],f32>) + -> !torch.vtensor<[1,64,1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // The ReduceMean operation as provided. + %211 = torch.operator "onnx.ReduceMean"(%arg0) {torch.onnx.axes = [-1 : si64]} : (!torch.vtensor<[1,64,768],f32>) -> !torch.vtensor<[1,64,1],f32> + return %211 : !torch.vtensor<[1,64,1],f32> +} + +// ----- +// Fixed. +func.func @cumsum_operation(%arg0: !torch.vtensor<[2,3],f64>, + %arg1: !torch.vtensor<[],si32>) + -> !torch.vtensor<[2,3],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %212 = torch.operator "onnx.CumSum"(%arg0, %arg1) : (!torch.vtensor<[2,3],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[2,3],f64> + return %212 : !torch.vtensor<[2,3],f64> +} \ No newline at end of file diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir index eba7546655e9..f063f234e4e5 100644 --- a/test/Conversion/TorchToLinalg/basic.mlir +++ b/test/Conversion/TorchToLinalg/basic.mlir @@ -29,9 +29,23 @@ func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v // ----- +// CHECK-LABEL: func.func @torch.aten.matmul.2d +func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> { + // CHECK-DAG: %[[LHS:.+]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[8,16],f32> -> tensor<8x16xf32> + // CHECK-DAG: %[[RHS:.+]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[16,8],f32> -> tensor<16x8xf32> + // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8xf32> + // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[EMPTY]] : tensor<8x8xf32>) -> tensor<8x8xf32> + // CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<8x16xf32>, tensor<16x8xf32>) outs(%[[FILL]] : tensor<8x8xf32>) -> tensor<8x8xf32> + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[8,16],f32>, !torch.vtensor<[16,8],f32> -> !torch.vtensor<[8,8],f32> + return %0 : !torch.vtensor<[8,8],f32> +} + +// ----- + // CHECK-LABEL: func.func @torch.aten.mm$basic_strict( // CHECK-NOT: assert -func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> +func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> attributes {torch.assume_strict_symbolic_shapes} { %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,2],f32> @@ -42,7 +56,7 @@ func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK-LABEL: func.func @torch.aten.mm$basic_unsigned( // CHECK: linalg.matmul_unsigned -func.func @torch.aten.mm$basic_unsigned(%arg0: !torch.vtensor<[?,?],ui32>, %arg1: !torch.vtensor<[?,?],ui32>) -> !torch.vtensor<[?,2],ui32> +func.func @torch.aten.mm$basic_unsigned(%arg0: !torch.vtensor<[?,?],ui32>, %arg1: !torch.vtensor<[?,?],ui32>) -> !torch.vtensor<[?,2],ui32> attributes {torch.assume_strict_symbolic_shapes} { %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],ui32>, !torch.vtensor<[?,?],ui32> -> !torch.vtensor<[?,2],ui32> @@ -287,3 +301,41 @@ func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtenso %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16> return %0 : !torch.vtensor<[?,?],f16> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.cat$convert( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG0]], %[[ARG1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor +// CHECK: %[[T3:.*]] = linalg.generic {{.*}} ins(%[[T2]] : tensor) outs(%{{.*}}: tensor) +// CHECK: %[[T4:.*]] = tensor.concat dim(0) %[[T1]], %[[T3]] : (tensor, tensor) -> tensor +// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> { + %int0 = torch.constant.int 0 + %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list + %1 = torch.aten.cat %0, %int0 : !torch.list, !torch.int -> !torch.vtensor<[?,?],f32> + return %1 : !torch.vtensor<[?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.cat( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %int0 = torch.constant.int 0 +// CHECK: %[[VAL_0:.*]] = torch.prim.ListConstruct %[[ARG_0]], %[[ARG_1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = tensor.concat dim(0) %[[VAL_1]], %[[VAL_2]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.cat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %int0 = torch.constant.int 0 + %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list + %1 = torch.aten.cat %0, %int0 : !torch.list, !torch.int -> !torch.vtensor<[?,?],f32> + return %1 : !torch.vtensor<[?,?],f32> +} diff --git a/test/Conversion/TorchToLinalg/elementwise.mlir b/test/Conversion/TorchToLinalg/elementwise.mlir index 00e408388b2c..bed94f98da2b 100644 --- a/test/Conversion/TorchToLinalg/elementwise.mlir +++ b/test/Conversion/TorchToLinalg/elementwise.mlir @@ -19,6 +19,8 @@ func.func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[] return %0 : !torch.vtensor<[],f32> } +// ----- + // CHECK-LABEL: func.func @elementwise$binary( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -46,6 +48,8 @@ func.func @elementwise$binary(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vt return %0 : !torch.vtensor<[?,?],f32> } +// ----- + // CHECK-LABEL: func.func @elementwise$ternary( // CHECK: linalg.generic {indexing_maps = [ // CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>, @@ -57,6 +61,8 @@ func.func @elementwise$ternary(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch return %0 : !torch.vtensor<[?,?,?],f32> } +// ----- + // CHECK-LABEL: func.func @elementwise$with_scalar_capture( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> { @@ -75,6 +81,8 @@ func.func @elementwise$with_scalar_capture(%arg0: !torch.vtensor<[?],f32>, %arg1 return %0 : !torch.vtensor<[?],f32> } +// ----- + // CHECK-LABEL: func.func @elementwise$static_1( // CHECK: linalg.generic {indexing_maps = [ // CHECK-SAME: affine_map<(d0) -> (d0)>, @@ -84,3 +92,13 @@ func.func @elementwise$static_1(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vt %1 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[?],f32> return %1 : !torch.vtensor<[?],f32> } + +// ----- + +// CHECK-LABEL: func.func @elementwise_sinh +// CHECK: linalg.generic +// CHECK: math.sinh +func.func @elementwise_sinh(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> { + %0 = torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} diff --git a/test/Conversion/TorchToLinalg/gridsampler.mlir b/test/Conversion/TorchToLinalg/gridsampler.mlir new file mode 100644 index 000000000000..d392860fa2c1 --- /dev/null +++ b/test/Conversion/TorchToLinalg/gridsampler.mlir @@ -0,0 +1,60 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK: #map +// CHECK-LABEL: func @grid_sampler +// CHECK-DAG: %[[TC0:.*]] = torch_c.to_builtin_tensor %[[ARG0:.*]] : !torch.vtensor<[4,10,10,4],f32> -> tensor<4x10x10x4xf32> +// CHECK-DAG: %[[TC1:.*]] = torch_c.to_builtin_tensor %[[ARG1:.*]] : !torch.vtensor<[4,6,8,2],f32> -> tensor<4x6x8x2xf32> +// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[C2_3:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[DIM:.*]] = tensor.dim %[[TC0]], %[[C2_3]] : tensor<4x10x10x4xf32> +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[DIM_4:.*]] = tensor.dim %[[TC0]], %[[C3]] : tensor<4x10x10x4xf32> +// CHECK-DAG: %[[X2:.*]] = arith.subi %[[DIM:.*]], %[[C1]] : index +// CHECK-DAG: %[[X3:.*]] = arith.subi %[[DIM_4]], %[[C1:.*]] : index +// CHECK-DAG: %[[X4:.*]] = arith.index_cast %[[X2]] : index to i64 +// CHECK-DAG: %[[X5:.*]] = arith.index_cast %[[X3]] : index to i64 +// CHECK-DAG: %[[X6:.*]] = arith.sitofp %[[X4]] : i64 to f32 +// CHECK-DAG: %[[X7:.*]] = arith.sitofp %[[X5]] : i64 to f32 +// CHECK-DAG: %[[X8:.*]] = arith.divf %[[X6]], %[[CST2]] : f32 +// CHECK-DAG: %[[X9:.*]] = arith.divf %[[X7]], %[[CST2]] : f32 +func.func @grid_sampler(%arg0: !torch.vtensor<[4,10,10,4],f32>, %arg1: !torch.vtensor<[4,6,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> { + %true = torch.constant.bool 0 + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 0 + %4 = torch.aten.grid_sampler %arg0, %arg1, %int0, %int1, %true : !torch.vtensor<[4,10,10,4],f32>, !torch.vtensor<[4,6,8,2],f32>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + return %4 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func @grid_sampler2 +// CHECK: #map +// CHECK-DAG: %[[X15:.*]] = arith.mulf %[[X13:.*]], %[[X8:.*]] : f32 +// CHECK-DAG: %[[X16:.*]] = arith.mulf %[[X14:.*]], %[[X9:.*]] : f32 +// CHECK-DAG: %[[X40:.*]] = arith.mulf %[[EXTRACTED:.*]], %[[X39:.*]] : f32 +// CHECK-DAG: %[[X41:.*]] = arith.mulf %[[X31:.*]], %[[X37:.*]] : f32 +// CHECK-DAG: %[[X42:.*]] = arith.addf %[[X40:.*]], %[[X41]] : f32 +// CHECK-DAG: %[[X43:.*]] = arith.subf %[[CST_1:.*]], %[[X37]] : f32 +// CHECK-DAG: %[[X45:.*]] = arith.mulf %[[X34:.*]], %[[X37]] : f32 +// CHECK-DAG: %[[X46:.*]] = arith.addf %[[X44:.*]], %[[X45]] : f32 +// CHECK-DAG: %[[X47:.*]] = arith.subf %[[CST_1]], %[[X38:.*]] : f32 +// CHECK-DAG: %[[X48:.*]] = arith.mulf %[[X42]], %[[XX47:.*]] : f32 +// CHECK-DAG: %[[X49:.*]] = arith.mulf %[[X46]], %[[XX38:.*]] : f32 +// CHECK-DAG: %[[X50:.*]] = arith.addf %[[X48]], %[[X49]] : f32 +// CHECK-DAG: linalg.yield %[[X50]] : f32 +// CHECK: } -> tensor +// CHECK: %[[X12:.*]] = torch_c.from_builtin_tensor %[[X11:.*]] : tensor -> !torch.vtensor<[?,?,?,?],f32> +// CHECK: return %[[X12]] : !torch.vtensor<[?,?,?,?],f32> +func.func @grid_sampler2(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { + %true = torch.constant.bool 0 + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 0 + %4 = torch.aten.grid_sampler %arg0, %arg1, %int0, %int1, %true : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + return %4 : !torch.vtensor<[?,?,?,?],f32> +} \ No newline at end of file diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir index df19ef7645e8..8a359ed5627d 100644 --- a/test/Conversion/TorchToLinalg/pooling.mlir +++ b/test/Conversion/TorchToLinalg/pooling.mlir @@ -1,7 +1,7 @@ // RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s -// CHECK-LABEL: func @forward -func.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { +// CHECK-LABEL: func @forward_max_pool2d +func.func @forward_max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { %int1 = torch.constant.int 1 %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 @@ -27,3 +27,49 @@ func.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?, %4 = torch.aten.max_pool2d %arg0, %kernel_size, %stride, %padding, %dilation, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> return %4 : !torch.vtensor<[?,?,?,?],f32> } + +// ----- + +// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2 * 2 + d5 * 3, d3 * 2 + d6 * 3, d4 * 2 + d7 * 3)> +// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> +// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> +// CHECK-LABEL: func @forward_max_pool3d +func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?,?],f32> { + %kernel_size1 = torch.constant.int 8 + %kernel_size2 = torch.constant.int 8 + %kernel_size3 = torch.constant.int 8 + + %stride1 = torch.constant.int 2 + %stride2 = torch.constant.int 2 + %stride3 = torch.constant.int 2 + + %padding1 = torch.constant.int 4 + %padding2 = torch.constant.int 4 + %padding3 = torch.constant.int 4 + + %dilation1 = torch.constant.int 3 + %dilation2 = torch.constant.int 3 + %dilation3 = torch.constant.int 3 + + %false = torch.constant.bool false + %kernel_size = torch.prim.ListConstruct %kernel_size1, %kernel_size2, %kernel_size3 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %stride = torch.prim.ListConstruct %stride1, %stride2, %stride3 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %padding = torch.prim.ListConstruct %padding1, %padding2, %padding3 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %dilation = torch.prim.ListConstruct %dilation1, %dilation2, %dilation3 : (!torch.int, !torch.int, !torch.int) -> !torch.list + + %4 = torch.aten.max_pool3d %arg0, %kernel_size, %stride, %padding, %dilation, %false : !torch.vtensor<[?,?,?,?,?],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[?,?,?,?,?],f32> + + // CHECK: %[[MIN_VALUE:.*]] = arith.constant 0xFF800000 : f32 + // CHECK: %[[PADDED_INPUT_TENSOR:.*]] = tensor.pad %{{.*}} low[0, 0, 4, 4, 4] high[0, 0, 4, 4, 4] { + // CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index): + // CHECK-NEXT: tensor.yield %[[MIN_VALUE:.*]] : f32 + // CHECK: } : tensor to tensor + + // CHECK: %[[OUTPUT_TENSOR:.*]] = linalg.fill ins(%[[MIN_VALUE:.*]] : f32) outs(%{{.*}} : tensor) -> tensor + // CHECK: %[[MAX_3D_POOL:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PADDED_INPUT_TENSOR:.*]], %{{.*}} : tensor, tensor) outs(%[[OUTPUT_TENSOR:.*]] : tensor) { + // CHECK-NEXT: ^bb0(%[[CURRENT_VALUE:.*]]: f32, %[[KERNEL:.*]]: f32, %[[ACC_OUT:.*]]: f32): + // CHECK-NEXT: %[[MAXF:.*]] = arith.maximumf %[[CURRENT_VALUE:.*]], %[[ACC_OUT:.*]] : f32 + // CHECK-NEXT: linalg.yield %[[MAXF:.*]] : f32 + // CHECK: } -> tensor + return %4 : !torch.vtensor<[?,?,?,?,?],f32> +} diff --git a/test/Conversion/TorchToLinalg/sparse.mlir b/test/Conversion/TorchToLinalg/sparse.mlir new file mode 100644 index 000000000000..5d952fde3509 --- /dev/null +++ b/test/Conversion/TorchToLinalg/sparse.mlir @@ -0,0 +1,36 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s + +// ----- + +#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> + +// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +// CHECK-LABEL: func.func @sum( +// CHECK-SAME: %[[A:.*]]: !torch.vtensor<[64,64],f32,#[[$CSR]]>) -> !torch.vtensor<[],f32> +// CHECK: %[[S:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[64,64],f32,#[[$CSR]]> -> tensor<64x64xf32, #[[$CSR]]> +// CHECK: linalg.generic {{{.*}}} ins(%[[S]] : tensor<64x64xf32, #[[$CSR]]>) +func.func @sum(%arg0: !torch.vtensor<[64,64],f32,#CSR>) -> !torch.vtensor<[],f32> { + %none = torch.constant.none + %0 = torch.aten.sum %arg0, %none + : !torch.vtensor<[64,64],f32,#CSR>, !torch.none -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} + +// ----- + +#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> + +// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +// CHECK-LABEL: func.func @SpMM( +// CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,16],f32,#[[$CSR]]>, +// CHECK-SAME: %[[B:.*]]: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> +// CHECK: %[[S:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[8,16],f32,#[[$CSR]]> -> tensor<8x16xf32, #[[$CSR]]> +// CHECK: %[[T:.*]] = torch_c.to_builtin_tensor %[[B]] : !torch.vtensor<[16,8],f32> -> tensor<16x8xf32> +// CHECK: linalg.matmul ins(%[[S]], %[[T]] : tensor<8x16xf32, #[[$CSR]]>, tensor<16x8xf32>) +func.func @SpMM(%arg0: !torch.vtensor<[8,16],f32,#CSR>, + %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> { + %0 = torch.aten.matmul %arg0, %arg1 + : !torch.vtensor<[8,16],f32,#CSR>, + !torch.vtensor<[16,8],f32> -> !torch.vtensor<[8,8],f32> + return %0 : !torch.vtensor<[8,8],f32> +} diff --git a/test/Conversion/TorchToLinalg/view.mlir b/test/Conversion/TorchToLinalg/view.mlir index 83424a17d843..4f9c1f867ee4 100644 --- a/test/Conversion/TorchToLinalg/view.mlir +++ b/test/Conversion/TorchToLinalg/view.mlir @@ -148,8 +148,8 @@ func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) - // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,3,5,?,6],f32> // [10,3,?,2,3] -> [30,?,6] -> [2,3,5,?,6] -// Associations are, -// -- for collapse, [0,1], [2], [3,4] and +// Associations are, +// -- for collapse, [0,1], [2], [3,4] and // -- for expand [0,1,2], [3], [4]. func.func @torch.aten.view$singleUnknownMatches0(%arg0: !torch.vtensor<[10,3,?,2,3],f32>) -> !torch.vtensor<[2,3,5,?,6],f32> { %int3 = torch.constant.int 3 diff --git a/test/Conversion/TorchToStablehlo/basic.mlir b/test/Conversion/TorchToStablehlo/basic.mlir index 51e3e6f9bdbb..5f096205ea8c 100644 --- a/test/Conversion/TorchToStablehlo/basic.mlir +++ b/test/Conversion/TorchToStablehlo/basic.mlir @@ -1,21 +1,6 @@ // RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s -// ----- - -// CHECK-LABEL: func.func @torch.aten.clone$basic( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[T1:.*]] = stablehlo.convert %[[T0]] : tensor -// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> -func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { - %none = torch.constant.none - %0 = torch.aten.clone %arg0, %none : !torch.vtensor<[?,?],f32>, !torch.none -> !torch.vtensor<[?,?],f32> - return %0 : !torch.vtensor<[?,?],f32> -} - // ----- // CHECK-LABEL: func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> { @@ -42,13 +27,9 @@ func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> { // CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar$basic( // CHECK-SAME: ) -> !torch.vtensor<[],si64> { -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[T0:.*]] = torch_c.to_i64 %[[INT1]] -// CHECK: %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = stablehlo.convert %[[T1]] : tensor<1xi64> -// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor -// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[],si64> -// CHECK: return %[[T4]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.*]] = stablehlo.constant dense<1> : tensor +// CHECK: %[[FROM:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor -> !torch.vtensor<[],si64> +// CHECK: return %[[FROM]] : !torch.vtensor<[],si64> func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> { %int1 = torch.constant.int 1 %0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[], si64> @@ -281,7 +262,7 @@ func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torc // ----- // CHECK-LABEL: func.func @torch.aten.cat( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %int0 = torch.constant.int 0 // CHECK: %[[VAL_0:.*]] = torch.prim.ListConstruct %[[ARG_0]], %[[ARG_1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list diff --git a/test/Conversion/TorchToStablehlo/linear.mlir b/test/Conversion/TorchToStablehlo/linear.mlir index b9bac97ca6c9..7f253a98df04 100644 --- a/test/Conversion/TorchToStablehlo/linear.mlir +++ b/test/Conversion/TorchToStablehlo/linear.mlir @@ -269,7 +269,7 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten // ----- // CHECK-LABEL: func.func @torch.aten.convolution( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?,3,3],f32>) -> !torch.vtensor<[?,?,?,?],f32> { // CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor // CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,3,3],f32> -> tensor @@ -306,7 +306,7 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: ! // ----- // CHECK-LABEL: func.func @torch.aten.convolution$bias( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG_1:.*]]: !torch.vtensor<[?,?,3,3],f32>, +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG_1:.*]]: !torch.vtensor<[?,?,3,3],f32>, // CHECK-SAME: %[[ARG_2:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { // CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor // CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,3,3],f32> -> tensor @@ -349,7 +349,7 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar // ----- // CHECK-LABEL: func.func @torch.aten.convolution$transposed_basic( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,9,9],f32> { // CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> // CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32> @@ -380,7 +380,7 @@ func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7, // ----- // CHECK-LABEL: func.func @torch.aten.convolution$transposed_stride( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,15,15],f32> { // CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> // CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32> @@ -415,7 +415,7 @@ func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7 // ----- // CHECK-LABEL: func.func @torch.aten.convolution$transposed_outputpadding( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,16,16],f32> { // CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> // CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32> @@ -450,7 +450,7 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor // ----- // CHECK-LABEL: func.func @torch.aten.convolution$transposed_groups( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,2,3,3],f32>) -> !torch.vtensor<[1,4,15,15],f32> { // CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> // CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,2,3,3],f32> -> tensor<2x2x3x3xf32> @@ -485,7 +485,7 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor // CHECK: %[[T_15:.*]] = stablehlo.transpose %[[T_14]], dims = [0, 1, 3, 2, 4] : (tensor<3x3x2x2x1xf32>) -> tensor<3x3x2x2x1xf32> // CHECK: %from_elements_3 = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_13]], %[[T_12]] : tensor<4xi64> // CHECK: %[[T_16:.*]] = stablehlo.dynamic_reshape %[[T_15]], %from_elements_3 : (tensor<3x3x2x2x1xf32>, tensor<4xi64>) -> tensor<3x3x4x1xf32> -// CHECK: %[[T_17:.*]] = stablehlo.convolution(%[[T_0]], %[[T_16]]) +// CHECK: %[[T_17:.*]] = stablehlo.convolution(%[[T_0]], %[[T_16]]) // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<3x3x4x1xf32>) -> tensor<1x4x15x15xf32> // CHECK: %[[T_18:.*]] = torch_c.from_builtin_tensor %[[T_17]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32> // CHECK: return %[[T_18]] : !torch.vtensor<[1,4,15,15],f32> diff --git a/test/Conversion/TorchToStablehlo/pooling.mlir b/test/Conversion/TorchToStablehlo/pooling.mlir index 426a43542477..b8fc6cbd8384 100644 --- a/test/Conversion/TorchToStablehlo/pooling.mlir +++ b/test/Conversion/TorchToStablehlo/pooling.mlir @@ -18,7 +18,7 @@ // CHECK: ^bb0(%[[VAL_8:.*]]: tensor, %[[VAL_9:.*]]: tensor): // CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor // CHECK: stablehlo.return %[[VAL_10]] : tensor -// CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor, tensor) -> tensor +// CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor // CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32> func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { @@ -50,8 +50,8 @@ func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch // CHECK: ^bb0(%[[VAL_8:.*]]: tensor, %[[VAL_9:.*]]: tensor): // CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor // CHECK: stablehlo.return %[[VAL_10]] : tensor -// CHECK: }) -// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor, tensor) -> tensor +// CHECK: }) +// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?,?,?],f32> func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { @@ -105,7 +105,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) - // CHECK: %[[T20:.*]] = stablehlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor, tensor // CHECK: %[[T21:.*]] = stablehlo.select %[[T18]], %[[T19]], %[[T20]] : tensor, tensor // CHECK: stablehlo.return %[[T17]], %[[T21]] : tensor, tensor -// CHECK: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 3, 3]> : tensor<3xi64>, window_strides = dense<[1, 2, 2]> : tensor<3xi64>} : (tensor, tensor, tensor, tensor) -> (tensor, tensor) +// CHECK: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor, tensor, tensor) -> (tensor, tensor) // CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]]#0 : tensor -> !torch.vtensor<[?,?,?],f32> // CHECK: %[[T15:.*]] = torch_c.from_builtin_tensor %[[T13]]#1 : tensor -> !torch.vtensor<[?,?,?],si64> // CHECK: return %[[T14]], %[[T15]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64> @@ -141,7 +141,7 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32> // CHECK: ^bb0(%[[IVAL_0:.*]]: tensor, %[[IVAL_1:.*]]: tensor): // CHECK: %[[IVAL_2:.*]] = stablehlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor // CHECK: stablehlo.return %[[IVAL_2]] : tensor -// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor, tensor) -> tensor +// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor // CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index // CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor @@ -162,7 +162,7 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32> // CHECK: ^bb0(%[[IVAL_3:.*]]: tensor, %[[IVAL_4:.*]]: tensor): // CHECK: %[[IVAL_5:.*]] = stablehlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor // CHECK: stablehlo.return %[[IVAL_5]] : tensor -// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor, tensor) -> tensor +// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor // CHECK: %[[VAL_20:.*]] = stablehlo.divide %[[VAL_6]], %[[VAL_19]] : tensor // CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32> @@ -198,7 +198,7 @@ func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): // CHECK: %[[T10:.*]] = stablehlo.add %[[ARG1]], %[[ARG2]] : tensor // CHECK: stablehlo.return %[[T10]] : tensor -// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor, tensor) -> tensor +// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor // CHECK: %[[T6:.*]] = stablehlo.constant dense<9> : tensor // CHECK: %[[T7:.*]] = stablehlo.convert %[[T6]] : (tensor) -> tensor // CHECK: %[[T8:.*]] = chlo.broadcast_divide %[[T5]], %[[T7]] : (tensor, tensor) -> tensor diff --git a/test/Conversion/TorchToTensor/torch_to_tensor.mlir b/test/Conversion/TorchToTensor/torch_to_tensor.mlir new file mode 100644 index 000000000000..277dabc3b891 --- /dev/null +++ b/test/Conversion/TorchToTensor/torch_to_tensor.mlir @@ -0,0 +1,8 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-tensor | FileCheck %s + +// CHECK-LABEL: func.func @test_shape +func.func @test_shape(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3],si64> { + // CHECK: %[[SHAPE:.+]] = arith.constant dense<[3, 4, 5]> : tensor<3xi64> + %0 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3],si64> + return %0 : !torch.vtensor<[3],si64> +} diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 7d046177fc14..e57467ba2416 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -96,24 +96,6 @@ func.func @torch.aten.matmul_4d(%arg0 : !torch.vtensor<[10,10,6,2],f32>, %arg1 : // ----- -// CHECK: %[[VAL_2:.+]] = tosa.reshape %0 {new_shape = array} : (tensor<10x6x2xf32>) -> tensor<1x10x6x2xf32> -// CHECK-NEXT: %[[VAL_3:.+]] = "tosa.const"() <{value = dense<[1, 0, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %[[VAL_4:.+]] = tosa.transpose %[[VAL_2]], %[[VAL_3]] : (tensor<1x10x6x2xf32>, tensor<4xi32>) -> tensor<10x1x6x2xf32> -// CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<10x1x6x2xf32>) -> tensor<10x6x2xf32> -// CHECK-NEXT: %[[VAL_6:.+]] = "tosa.const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %[[VAL_7:.+]] = tosa.transpose %1, %[[VAL_6]] : (tensor<10x10x2x6xf32>, tensor<4xi32>) -> tensor<10x2x10x6xf32> -// CHECK-NEXT: %[[VAL_8:.+]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<10x2x10x6xf32>) -> tensor<10x2x60xf32> -// CHECK-NEXT: %[[VAL_9:.+]] = tosa.matmul %[[VAL_5]], %[[VAL_8]] : (tensor<10x6x2xf32>, tensor<10x2x60xf32>) -> tensor<10x6x60xf32> -// CHECK-NEXT: %[[VAL_10:.+]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<10x6x60xf32>) -> tensor<10x6x10x6xf32> -// CHECK-NEXT: %[[VAL_11:.+]] = "tosa.const"() <{value = dense<[0, 2, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %[[VAL_12:.+]] = tosa.transpose %[[VAL_10]], %[[VAL_11]] : (tensor<10x6x10x6xf32>, tensor<4xi32>) -> tensor<10x10x6x6xf32> -func.func @torch.aten.matmul_4d_broadcast(%arg0 : !torch.vtensor<[10,6,2],f32>, %arg1 : !torch.vtensor<[10,10,2,6],f32>) -> !torch.vtensor<[?,?,?,?],f32> { - %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[10,6,2],f32>, !torch.vtensor<[10,10,2,6],f32> -> !torch.vtensor<[?,?,?,?],f32> - return %0 : !torch.vtensor<[?,?,?,?],f32> -} - -// ----- - // CHECK: %[[VAL_2:.+]] = tosa.reshape %0 {new_shape = array} : (tensor<4x1x5x6xf32>) -> tensor<1x20x6xf32> // CHECK-NEXT: %[[VAL_3:.+]] = "tosa.const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK-NEXT: %[[VAL_4:.+]] = tosa.transpose %1, %[[VAL_3]] : (tensor<1x3x6x7xf32>, tensor<4xi32>) -> tensor<6x1x3x7xf32> @@ -143,17 +125,16 @@ func.func @torch.aten.matmul_3d_broadcast(%arg0 : !torch.vtensor<[100,4,8],f32>, // ----- -// CHECK: %[[VAL_2:.+]] = tosa.matmul %0, %1 : (tensor<100x4x8xf16>, tensor<100x8x16xf16>) -> tensor<100x4x16xf32> -// CHECK-NEXT: %[[VAL_3:.+]] = tosa.cast %[[VAL_2]] : (tensor<100x4x16xf32>) -> tensor<100x4x16xf16> +// CHECK-LABEL: torch.aten.bmm_3d_fp16 +// CHECK: %[[VAL_2:.+]] = tosa.matmul %0, %1 : (tensor<100x4x8xf16>, tensor<100x8x16xf16>) -> tensor<100x4x16xf16> func.func @torch.aten.bmm_3d_fp16(%arg0 : !torch.vtensor<[100,4,8],f16>, %arg1 : !torch.vtensor<[100,8,16],f16>) -> !torch.vtensor<[?,?,?],f16> { %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[100,4,8],f16>, !torch.vtensor<[100,8,16],f16> -> !torch.vtensor<[?,?,?],f16> return %0 : !torch.vtensor<[?,?,?],f16> } // ----- - -// CHECK: %[[VAL_2:.+]] = tosa.matmul %0, %1 : (tensor<100x4x8xbf16>, tensor<100x8x16xbf16>) -> tensor<100x4x16xf32> -// CHECK-NEXT: %[[VAL_3:.+]] = tosa.cast %[[VAL_2]] : (tensor<100x4x16xf32>) -> tensor<100x4x16xbf16> +// CHECK-LABEL: torch.aten.bmm_3d_bf16 +// CHECK: %[[VAL_2:.+]] = tosa.matmul %0, %1 : (tensor<100x4x8xbf16>, tensor<100x8x16xbf16>) -> tensor<100x4x16xbf16> func.func @torch.aten.bmm_3d_bf16(%arg0 : !torch.vtensor<[100,4,8],bf16>, %arg1 : !torch.vtensor<[100,8,16],bf16>) -> !torch.vtensor<[?,?,?],bf16> { %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[100,4,8],bf16>, !torch.vtensor<[100,8,16],bf16> -> !torch.vtensor<[?,?,?],bf16> return %0 : !torch.vtensor<[?,?,?],bf16> @@ -406,6 +387,34 @@ func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> ! // ----- +// CHECK-LABEL: func.func @test_linalg_vector_norm$basic( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,151,64],f32>) -> !torch.vtensor<[3,151,1],f32> { +// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[3,151,64],f32> -> tensor<3x151x64xf32> +// CHECK: %[[ARG1:.*]] = torch.constant.float 2.000000e+00 +// CHECK: %[[ARG2:.*]] = torch.constant.int -1 +// CHECK: %[[ARG3:.*]] = torch.constant.bool true +// CHECK: %[[ARG4:.*]] = torch.constant.none +// CHECK: %[[ARG5:.*]] = torch.prim.ListConstruct %[[ARG2]] : (!torch.int) -> !torch.list +// CHECK: %[[ARG6:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[ARG7:.*]] = tosa.abs %[[ARG0_BUILTIN]] : (tensor<3x151x64xf32>) -> tensor<3x151x64xf32> +// CHECK: %[[ARG8:.*]] = tosa.pow %[[ARG7]], %[[ARG6]] : (tensor<3x151x64xf32>, tensor) -> tensor<3x151x64xf32> +// CHECK: %[[ARG9:.*]] = tosa.reduce_sum %[[ARG8]] {axis = 2 : i32} : (tensor<3x151x64xf32>) -> tensor<3x151x1xf32> +// CHECK: %[[ARG10:.*]] = tosa.reciprocal %[[ARG6]] : (tensor) -> tensor +// CHECK: %[[ARG11:.*]] = tosa.pow %[[ARG9]], %[[ARG10]] : (tensor<3x151x1xf32>, tensor) -> tensor<3x151x1xf32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[ARG11]] : tensor<3x151x1xf32> -> !torch.vtensor<[3,151,1],f32> +// CHECK: return %[[RESULT]] : !torch.vtensor<[3,151,1],f32> +func.func @test_linalg_vector_norm$basic(%arg0: !torch.vtensor<[3,151,64],f32>) -> (!torch.vtensor<[3,151,1],f32>) { + %float2.000000e00 = torch.constant.float 2.000000e+00 + %int-1 = torch.constant.int -1 + %true = torch.constant.bool true + %none = torch.constant.none + %1 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list + %2 = torch.aten.linalg_vector_norm %arg0, %float2.000000e00, %1, %true, %none : !torch.vtensor<[3,151,64],f32>, !torch.float, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,151,1],f32> + return %2 : !torch.vtensor<[3,151,1],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_reduce_sum$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> { // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor @@ -790,6 +799,22 @@ func.func @torch.aten.ne.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // ----- +// CHECK-LABEL: func.func @torch.aten.logical_or$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],i1>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],i1> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],i1> -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.logical_or %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> +// CHECK: } +func.func @torch.aten.logical_or$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> { + %0 = torch.aten.logical_or %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1> + return %0 : !torch.vtensor<[?,?],i1> +} + +// ----- + // CHECK-LABEL: func.func @forward( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,2,4],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,2],f32> -> tensor<3x4x2xf32> @@ -1201,6 +1226,63 @@ func.func @torch.aten.Scalar$basic(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> return %0 : !torch.vtensor<[1,1,128,128],si64> } +// ----- +// CHECK-LABEL: func.func @torch.aten.slice.negative_start( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,16,256],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 100 +// CHECK: %[[VAL_5:.*]] = torch.constant.int -16 +// CHECK: %[[VAL_1r:.*]] = tosa.reshape +// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_1r]] {size = array, start = array} : (tensor<4x65x1x256xf32>) -> tensor<4x16x1x256xf32> +// CHECK: %[[VAL_4r:.*]] = tosa.reshape +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4r]] : tensor<4x16x256xf32> -> !torch.vtensor<[4,16,256],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,16,256],f32> +// CHECK: } +func.func @torch.aten.slice.negative_start(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,16,256],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int100 = torch.constant.int 100 + %int-16 = torch.constant.int -16 + %0 = torch.aten.slice.Tensor %arg0, %int1, %int-16, %int100, %int1 : !torch.vtensor<[4,65,256],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,16,256],f32> + return %0 : !torch.vtensor<[4,16,256],f32> +} + +// ----- +// CHECK-LABEL: func.func @torch.aten.clamp.min_none( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],si64> -> tensor<1x1x128x128xi64> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.none +// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0.000000e+00 : f32, max_int = 0 : i64, min_fp = -3.40282347E+38 : f32, min_int = -9223372036854775808 : i64} : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi64> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[1,1,128,128],si64> +// CHECK: } +func.func @torch.aten.clamp.min_none(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> { + %int0 = torch.constant.int 0 + %none = torch.constant.none + %0 = torch.aten.clamp %arg0, %none, %int0 : !torch.vtensor<[1,1,128,128],si64>, !torch.none, !torch.int -> !torch.vtensor<[1,1,128,128],si64> + return %0 : !torch.vtensor<[1,1,128,128],si64> +} + +// ----- +// CHECK-LABEL: func.func @torch.aten.clamp.max_none( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],si64> -> tensor<1x1x128x128xi64> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.none +// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 3.40282347E+38 : f32, max_int = 9223372036854775807 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi64> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[1,1,128,128],si64> +// CHECK: } +func.func @torch.aten.clamp.max_none(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> { + %int0 = torch.constant.int 0 + %none = torch.constant.none + %0 = torch.aten.clamp %arg0, %int0, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.int, !torch.none -> !torch.vtensor<[1,1,128,128],si64> + return %0 : !torch.vtensor<[1,1,128,128],si64> +} + // ----- // CHECK-LABEL: func.func @torch.aten.clamp( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> { @@ -1224,7 +1306,7 @@ func.func @torch.aten.clamp(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],f32> -> tensor<1x1x128x128xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00 // CHECK: %[[VAL_3:.*]] = torch.constant.float 6.432100e+00 -// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 6.432100e+00 : f32, max_int = 0 : i64, min_fp = 3.123400e+00 : f32, min_int = 0 : i64} : (tensor<1x1x128x128xf32>) -> tensor<1x1x128x128xf32> +// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 6.432100e+00 : f32, max_int = 6 : i64, min_fp = 3.123400e+00 : f32, min_int = 3 : i64} : (tensor<1x1x128x128xf32>) -> tensor<1x1x128x128xf32> // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1x1x128x128xf32> -> !torch.vtensor<[1,1,128,128],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[1,1,128,128],f32> // CHECK: } diff --git a/test/Conversion/TorchToTosa/cast_fp32_to_fp16.mlir b/test/Conversion/TorchToTosa/cast_fp32_to_fp16.mlir new file mode 100644 index 000000000000..5504ac0e4002 --- /dev/null +++ b/test/Conversion/TorchToTosa/cast_fp32_to_fp16.mlir @@ -0,0 +1,12 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file + +// CHECK: %{{.*}} = tosa.cast %{{.*}} : (tensor<1x32x220x220xf32>) -> tensor<1x32x220x220xf16> +func.func @forward(%arg0: !torch.vtensor<[1,32,220,220],f32>) -> !torch.vtensor<[1,32,220,220],f16> { + %int5 = torch.constant.int 5 + %false = torch.constant.bool false + %none = torch.constant.none + %out = torch.aten.to.dtype %arg0, %int5, %false, %false, %none : !torch.vtensor<[1,32,220,220],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,32,220,220],f16> + return %out : !torch.vtensor<[1,32,220,220],f16> +} + + diff --git a/test/Conversion/TorchToTosa/conv2d_transpose.mlir b/test/Conversion/TorchToTosa/conv2d_transpose.mlir new file mode 100644 index 000000000000..7f0d5e2ab25b --- /dev/null +++ b/test/Conversion/TorchToTosa/conv2d_transpose.mlir @@ -0,0 +1,18 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics + +// The following test ensures that a tranposed convolution op is not +// lowered in the torch-to-tosa conversion pass. + +func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> { + %true = torch.constant.bool true + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %weight = torch.vtensor.literal(dense<0.0> : tensor<64x64x3x3xf32>) : !torch.vtensor<[64,64,3,3],f32> + %bias = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32> + %stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + // expected-error@+1 {{failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal}} + %output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,64,2,200],f32> + return %output : !torch.vtensor<[1,64,2,200],f32> +} + diff --git a/test/Dialect/TMTensor/bufferize.mlir b/test/Dialect/TMTensor/bufferize.mlir index 3e60814fabdd..f36a2f521ad1 100644 --- a/test/Dialect/TMTensor/bufferize.mlir +++ b/test/Dialect/TMTensor/bufferize.mlir @@ -64,7 +64,7 @@ func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: // CHECK: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32> // CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32> // CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32> -// CHECK: tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]] +// CHECK: tm_tensor.scatter {dimension_map = array} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]] // CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) { // CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32): // CHECK: tm_tensor.yield %[[UPDATE_SCALAR]] : i32 @@ -74,7 +74,7 @@ func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: func.func @scatter_update_scalar_1D( %original: tensor<8xi32>, %indices: tensor<3x1xi32>, %updates: tensor<3xi32>) -> tensor<8xi32> { - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map = array} unique_indices(true) ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>) outs(%original : tensor<8xi32>) { ^bb0(%update: i32, %orig: i32): // no predecessors @@ -92,7 +92,7 @@ func.func @scatter_update_scalar_1D( // CHECK: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32> // CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32> // CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32> -// CHECK: tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]] +// CHECK: tm_tensor.scatter {dimension_map = array} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]] // CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) { // CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32): // CHECK: %[[CST1:.*]] = arith.constant 1 : i32 @@ -104,7 +104,7 @@ func.func @scatter_update_scalar_1D( func.func @scatter_add_scalar_1D( %original: tensor<8xi32>, %indices: tensor<3x1xi32>, %updates: tensor<3xi32>) -> tensor<8xi32> { - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map = array} unique_indices(true) ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>) outs(%original : tensor<8xi32>) { ^bb0(%update: i32, %orig: i32): // no predecessors diff --git a/test/Dialect/TMTensor/convert_to_loops.mlir b/test/Dialect/TMTensor/convert_to_loops.mlir index e9c160f99e94..7901cf505f2a 100644 --- a/test/Dialect/TMTensor/convert_to_loops.mlir +++ b/test/Dialect/TMTensor/convert_to_loops.mlir @@ -105,7 +105,7 @@ func.func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) { func.func @scatter_update_scalar_1D( %original: memref<8xi32>, %indices: memref<3x1xi32>, %updates: memref<3xi32>) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>) outs(%original : memref<8xi32>) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -131,7 +131,7 @@ func.func @scatter_update_scalar_1D( func.func @scatter_add_scalar_2D( %original: memref<4x3xi32>, %indices: memref<3x2xi32>, %updates: memref<3xi32>) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref<3xi32>, memref<3x2xi32>) outs(%original : memref<4x3xi32>) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -162,7 +162,7 @@ func.func @scatter_add_scalar_2D( func.func @scatter_update_slice_2D( %original: memref<4x3xi32>, %indices: memref<2x1xi32>, %updates: memref<2x3xi32>) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>) outs(%original : memref<4x3xi32>) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -192,7 +192,7 @@ func.func @scatter_update_slice_2D( func.func @scatter_add_scalar_1D( %original: memref<8xi32>, %indices: memref<3x1xi32>, %updates: memref<3xi32>) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>) outs(%original : memref<8xi32>) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -221,7 +221,7 @@ func.func @scatter_add_scalar_1D( func.func @scatter_add_slice_2D( %original: memref<4x3xi32>, %indices: memref<2x1xi32>, %updates: memref<2x3xi32>) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>) outs(%original : memref<4x3xi32>) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -251,7 +251,7 @@ func.func @scatter_add_slice_2D( func.func @scatter_update_scalar_dynamic_1D( %original: memref, %indices: memref, %updates: memref) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref, memref) outs(%original : memref) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -277,7 +277,7 @@ func.func @scatter_update_scalar_dynamic_1D( func.func @scatter_add_scalar_dynamic_2D( %original: memref, %indices: memref, %updates: memref) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref, memref) outs(%original : memref) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -308,7 +308,7 @@ func.func @scatter_add_scalar_dynamic_2D( func.func @scatter_update_slice_dynamic_2D( %original: memref, %indices: memref, %updates: memref) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref, memref) outs(%original : memref) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -335,6 +335,7 @@ func.func @scatter_update_slice_dynamic_2D( func.func @scatter_partial_slices(%arg0: memref<2x64x12xf32>, %arg1: memref<2x3xi32>, %arg2: memref<2x1x12xf32>) { tm_tensor.scatter + {dimension_map= array} unique_indices(true) ins(%arg2, %arg1 : memref<2x1x12xf32>, memref<2x3xi32>) outs(%arg0 : memref<2x64x12xf32>) { diff --git a/test/Dialect/TMTensor/invalid.mlir b/test/Dialect/TMTensor/invalid.mlir index bfcd1adb8152..6653d944a059 100644 --- a/test/Dialect/TMTensor/invalid.mlir +++ b/test/Dialect/TMTensor/invalid.mlir @@ -4,7 +4,7 @@ func.func @scatter_mixed_tensor_memref( %update : memref, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : memref, tensor) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -20,7 +20,7 @@ func.func @scatter_mixed_tensor_memref( %update : tensor, %indices : memref, %original : tensor) -> tensor { // expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, memref) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -36,7 +36,7 @@ func.func @scatter_extra_outputs( %update : tensor, %indices : tensor, %original : tensor) -> (tensor, tensor) { // expected-error @+1 {{expected number of outputs to be same as the number of results}} - %0, %1 = tm_tensor.scatter unique_indices(true) + %0, %1 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -52,7 +52,7 @@ func.func @scatter_mixed_tensor_memref( %update : tensor, %indices : tensor, %original : memref) -> tensor { // expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : memref) { ^bb0(%arg1: f32, %arg2: f32): @@ -68,7 +68,7 @@ func.func @scatter_output_type_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor<4x?xf32> { // expected-error @+1 {{expected type of `outs` operand #0 'tensor' to be same as result type 'tensor<4x?xf32>'}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -84,7 +84,7 @@ func.func @scatter_mixed_tensor_memref( %update : memref, %indices : tensor, %original : memref) { // expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}} - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : memref, tensor) outs(%original : memref) { ^bb0(%arg1: f32, %arg2: f32): @@ -100,7 +100,7 @@ func.func @scatter_mixed_tensor_memref( %update : memref, %indices : memref, %original : tensor) { // expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}} - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : memref, memref) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -116,7 +116,7 @@ func.func @scatter_dim_mismatch( %update : tensor, %indices : tensor<48x1xi32>, %original : tensor) -> tensor { // expected-error @+1 {{mismatch in shape of indices and update value at dim#0}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor<48x1xi32>) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -132,7 +132,7 @@ func.func @scatter_dim_mismatch( %update : tensor<64x?xf32>, %indices : tensor<48x1xi32>, %original : tensor) -> tensor { // expected-error @+1 {{mismatch in shape of indices and update value at dim#0}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor<64x?xf32>, tensor<48x1xi32>) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -148,7 +148,7 @@ func.func @scatter_dim_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{op update value rank exceeds the rank of the original value}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -162,16 +162,16 @@ func.func @scatter_dim_mismatch( func.func @scatter_dim_mismatch( %update : tensor, %indices : tensor, - %original : tensor) -> tensor { - // expected-error @+1 {{mismatch in shape of update value dim#1 and original value at dim#1}} - %0 = tm_tensor.scatter unique_indices(true) + %original : tensor) -> tensor { + // expected-error @+1 {{shape of update value dim#1 exceeds original value at dim#1}} + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { + outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): %1 = arith.addf %arg1, %arg2 : f32 tm_tensor.yield %1 : f32 - } -> tensor - return %0 : tensor + } -> tensor + return %0 : tensor } // ----- @@ -180,7 +180,7 @@ func.func @scatter_region_type_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{expected region to have scalar argument of integer or float types}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: index, %arg2: index): @@ -197,7 +197,7 @@ func.func @scatter_region_type_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{mismatch in argument 0 of region 'i64' and element type of update value 'i32'}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i64, %arg2: i32): @@ -214,7 +214,7 @@ func.func @scatter_region_type_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{mismatch in argument 1 of region 'i64' and element type of original value 'i32'}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i32, %arg2: i64): @@ -231,7 +231,7 @@ func.func @scatter_region_type_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{mismatch in region argument types 'i32' and 'i64'}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i32, %arg2: i64): @@ -248,7 +248,7 @@ func.func @scatter_region_type_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{expected region to have two arguments}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i64, %arg2: i64, %arg3 : i64): @@ -264,7 +264,7 @@ func.func @scatter_region_type_mismatch( func.func @scatter_yield_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -281,7 +281,7 @@ func.func @scatter_yield_mismatch( func.func @scatter_yield_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -299,7 +299,7 @@ func.func @scatter_index_depth_dynamic( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{expected index depth is static}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -316,7 +316,7 @@ func.func @scatter_original_rank_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{op index depth and update value does not cover rank of original value}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i64, %arg2: i64): diff --git a/test/Dialect/Torch/GlobalizeObjectGraph/submodules.mlir b/test/Dialect/Torch/GlobalizeObjectGraph/submodules.mlir index eacd36493791..7fc261850def 100644 --- a/test/Dialect/Torch/GlobalizeObjectGraph/submodules.mlir +++ b/test/Dialect/Torch/GlobalizeObjectGraph/submodules.mlir @@ -1,6 +1,6 @@ // RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s -// Check that linkage names consist of the dotted path from the root. +// Check that linkage names consist of the dotted path from the root. // CHECK-LABEL: torch.global_slot.module_initializer { // CHECK: %[[FLOAT:.*]] = torch.constant.float 4.200000e+01 diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 5dfd8daa9d44..a607365f4918 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1,10 +1,10 @@ -// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s +// RUN: torch-mlir-opt %s -canonicalize --split-input-file | FileCheck %s // CHECK-LABEL: func.func @torch.aten.__range_length$fold() -> (!torch.int, !torch.int, !torch.int, !torch.int) { -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[INT3:.*]] = torch.constant.int 3 -// CHECK: %[[INTM1:.*]] = torch.constant.int -1 +// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1 +// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2 +// CHECK-DAG: %[[INT3:.*]] = torch.constant.int 3 +// CHECK-DAG: %[[INTM1:.*]] = torch.constant.int -1 // CHECK: %[[NEG_STEP:.*]] = torch.aten.__range_length %[[INT1]], %[[INT3]], %[[INTM1]] : !torch.int, !torch.int, !torch.int -> !torch.int // CHECK: return %[[INT2]], %[[INT2]], %[[INT1]], %[[NEG_STEP]] : !torch.int, !torch.int, !torch.int, !torch.int func.func @torch.aten.__range_length$fold() -> (!torch.int, !torch.int, !torch.int, !torch.int) { @@ -29,6 +29,46 @@ func.func @torch.runtime.assert() { return } +// CHECK-LABEL: func.func @torch.aten.ones_item +// CHECK: %[[CONST:.*]] = torch.constant.int 1 +// CHECK: return %[[CONST]] : !torch.int +func.func @torch.aten.ones_item() -> !torch.int { + %int1 = torch.constant.int 1 + %int3 = torch.constant.int 3 + %none = torch.constant.none + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.aten.ones %0, %int3, %none, %none, %none : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32> + %2 = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int + return %2 : !torch.int +} + +// CHECK-LABEL: func.func @torch.aten.zeros_item +// CHECK: %[[CONST:.*]] = torch.constant.int 0 +// CHECK: return %[[CONST]] : !torch.int +func.func @torch.aten.zeros_item() -> !torch.int { + %int1 = torch.constant.int 1 + %int3 = torch.constant.int 3 + %none = torch.constant.none + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.aten.zeros %0, %int3, %none, %none, %none : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32> + %2 = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int + return %2 : !torch.int +} + +// CHECK-LABEL: func.func @torch.aten.full_item +// CHECK: %[[CONST:.*]] = torch.constant.int 1337 +// CHECK: return %[[CONST]] : !torch.int +func.func @torch.aten.full_item() -> !torch.int { + %int1 = torch.constant.int 1 + %int3 = torch.constant.int 1337 + %int5 = torch.constant.int 5 + %none = torch.constant.none + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.aten.full %0, %int3, %int5, %none, %none, %none : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32> + %2 = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int + return %2 : !torch.int +} + // CHECK-LABEL: func.func @torch.aten.is_floating_point$fold_true // CHECK: %[[TRUE:.*]] = torch.constant.bool true // CHECK: return %[[TRUE]] : !torch.bool @@ -1421,6 +1461,17 @@ func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !to return %0 : !torch.tensor<[],f32> } +// CHECK-LABEL: func.func @torch.aten.tensor$one_elem( +// CHECK-NEXT: torch.vtensor.literal(dense<42> : tensor<1xsi64>) : !torch.vtensor<[1],si64> +func.func @torch.aten.tensor$one_elem() -> (!torch.vtensor<[1],si64>) { + %none = torch.constant.none + %false = torch.constant.bool false + %int42 = torch.constant.int 42 + %66 = torch.prim.ListConstruct %int42 : (!torch.int) -> !torch.list + %67 = torch.aten.tensor %66, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64> + return %67 : !torch.vtensor<[1],si64> +} + // CHECK-LABEL: func.func @torch.aten.to.dtype$same_dtype( // CHECK-SAME: %[[ARG:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> { // CHECK-NEXT: return %[[ARG]] : !torch.tensor<*,f32> @@ -1636,13 +1687,8 @@ func.func @torch.aten.Bool.int$fold_cst() -> !torch.bool { } // CHECK-LABEL: func.func @torch.aten.add.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR3]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.add.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 @@ -1654,11 +1700,8 @@ func.func @torch.aten.add.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor } // CHECK-LABEL: @torch.aten.add.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR2]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.add.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 @@ -1709,11 +1752,8 @@ func.func @prim.ListUnpack$fold_list(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !t } // CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT3:.*]] = torch.constant.int 3 -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR2]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<3> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { %int6 = torch.constant.int 6 %str = torch.constant.str "floor" @@ -1724,13 +1764,8 @@ func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtenso } // CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT3:.*]] = torch.constant.int 3 -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR3]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<3> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { %int6 = torch.constant.int 6 %int2 = torch.constant.int 2 @@ -1742,11 +1777,8 @@ func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d() -> !torch.vt } // CHECK-LABEL: func.func @torch.aten.add.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR1]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.add.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 @@ -1757,9 +1789,8 @@ func.func @torch.aten.add.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor } // CHECK-LABEL: func.func @torch.aten.add.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR0]] : !torch.vtensor<[],si64> +// CHECK: %[[CST]] = torch.vtensor.literal(dense<6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.add.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 @@ -1769,13 +1800,8 @@ func.func @torch.aten.add.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[], } // CHECK-LABEL: func.func @torch.aten.sub.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT_6:.*]] = torch.constant.int -6 -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR3]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.sub.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 @@ -1787,11 +1813,8 @@ func.func @torch.aten.sub.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor } // CHECK-LABEL: @torch.aten.sub.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT_6:.*]] = torch.constant.int -6 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR2]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] func.func @torch.aten.sub.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 @@ -1803,11 +1826,8 @@ func.func @torch.aten.sub.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[], } // CHECK-LABEL: func.func @torch.aten.sub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT_6:.*]] = torch.constant.int -6 -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR1]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.sub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 @@ -1818,9 +1838,8 @@ func.func @torch.aten.sub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor } // CHECK-LABEL: func.func @torch.aten.sub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT_6:.*]] = torch.constant.int -6 -// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR0]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.sub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 @@ -1840,9 +1859,8 @@ func.func @torch.aten.sub.float$fold() -> !torch.float { } // CHECK-LABEL: func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT6]] = torch.constant.int 6 -// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR0]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { %int3 = torch.constant.int 3 %0 = torch.vtensor.literal(dense<2> : tensor) : !torch.vtensor<[],si64> @@ -1851,11 +1869,8 @@ func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[], } // CHECK-LABEL: func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR1]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 @@ -1865,9 +1880,8 @@ func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor } // CHECK-LABEL: func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT6]] = torch.constant.int 6 -// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR0]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { %0 = torch.vtensor.literal(dense<2> : tensor) : !torch.vtensor<[],si64> %1 = torch.vtensor.literal(dense<3> : tensor) : !torch.vtensor<[],si64> @@ -1876,13 +1890,8 @@ func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[], } // CHECK-LABEL: func.func @torch.aten.mul.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[INT3:.*]] = torch.constant.int 3 -// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR2]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.mul.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 @@ -1893,13 +1902,8 @@ func.func @torch.aten.mul.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor } // CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d_trunc() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT3:.*]] = torch.constant.int 3 -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR3]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<3> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d_trunc() -> !torch.vtensor<[],si64> { %int6 = torch.constant.int 6 %int2 = torch.constant.int 2 @@ -1911,11 +1915,8 @@ func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d_trunc() -> !to } // CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d_trunc() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT3:.*]] = torch.constant.int 3 -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR2]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<3> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d_trunc() -> !torch.vtensor<[],si64> { %int6 = torch.constant.int 6 %str = torch.constant.str "trunc" @@ -1961,6 +1962,36 @@ func.func @torch.aten.sort.int$reverse_true() -> !torch.list { return %0 : !torch.list } +// CHECK-LABEL: @torch.aten.sort$unary_element +// CHECK : %[[INDICES:.*]] = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64> +// CHECK-NOT : torch.aten.sort %arg +// CHECK : return %arg0, %[[INDICES]] : !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> +func.func @torch.aten.sort$unary_element(%arg0 : !torch.vtensor<[1],si64>, %arg1 : !torch.int, %arg2 : !torch.bool) -> (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) { + %0, %1 = torch.aten.sort %arg0, %arg1, %arg2 : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> + return %0, %1 : !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> +} + + +// CHECK-LABEL: @torch.aten.sort$unary_dim +// CHECK : %[[INDICES:.*]] = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64> +// CHECK-NOT : torch.aten.sort %arg +// CHECK : return %arg0, %[[INDICES]] : !torch.vtensor<[3, 1,4],si64>, !torch.vtensor<[1],si64> +func.func @torch.aten.sort$unary_dim(%arg0 : !torch.vtensor<[3, 1, 4],si64>, %arg1 : !torch.bool) -> (!torch.vtensor<[3, 1, 4],si64>, !torch.vtensor<[1],si64>) { + %dim = torch.constant.int 1 + %0, %1 = torch.aten.sort %arg0, %dim, %arg1 : !torch.vtensor<[3, 1, 4],si64>, !torch.int, !torch.bool -> !torch.vtensor<[3, 1, 4],si64>, !torch.vtensor<[1],si64> + return %0, %1 : !torch.vtensor<[3, 1,4],si64>, !torch.vtensor<[1],si64> +} + +// CHECK-LABEL: @torch.aten.sort$nofold +// CHECK : torch.aten.sort %arg +func.func @torch.aten.sort$nofold (%arg0 : !torch.vtensor<[3, 1, 4],si64>, %arg1 : !torch.bool) -> (!torch.vtensor<[3, 1, 4],si64>, !torch.vtensor<[3],si64>) { + %dim = torch.constant.int 0 + %0, %1 = torch.aten.sort %arg0, %dim, %arg1 : !torch.vtensor<[3, 1, 4],si64>, !torch.int, !torch.bool -> !torch.vtensor<[3, 1, 4],si64>, !torch.vtensor<[3],si64> + return %0, %1 : !torch.vtensor<[3, 1, 4],si64>, !torch.vtensor<[3],si64> +} + +// ----- + // CHECK-LABEL: @torch.aten.cat$fold_single_operand // CHECK-SAME: %[[ARG0:.+]]: !torch.tensor // CHECK: return %[[ARG0]] : !torch.tensor @@ -1971,6 +2002,22 @@ func.func @torch.aten.cat$fold_single_operand(%arg0: !torch.tensor) -> !torch.te return %1: !torch.tensor } +// ----- + +// CHECK-LABEL: @torch.aten.cat$fold_zero_dim_operand +// CHECK: %[[FOLD:.+]] = torch.vtensor.literal(dense<[1, 3, 2, 2]> : tensor<4xsi32>) +// CHECK: return %[[FOLD]] : !torch.vtensor +func.func @torch.aten.cat$fold_zero_dim_operand() -> !torch.vtensor<[4],si32> { + %0 = torch.vtensor.literal(dense<[1, 3]> : tensor<2xsi32>) : !torch.vtensor<[2],si32> + %1 = torch.vtensor.literal(dense<2> : tensor<2xsi32>) : !torch.vtensor<[2],si32> + %int0 = torch.constant.int 0 + %list = torch.prim.ListConstruct %0, %1 : (!torch.vtensor<[2],si32>, !torch.vtensor<[2],si32>) -> !torch.list + %cat = torch.aten.cat %list, %int0 : !torch.list, !torch.int -> !torch.vtensor<[4],si32> + return %cat: !torch.vtensor<[4],si32> +} + +// ----- + // CHECK-LABEL: func.func @torch.aten.broadcast_to$fold( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> { // CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[3,4,2],f32> @@ -1983,15 +2030,23 @@ func.func @torch.aten.broadcast_to$fold(%arg0: !torch.vtensor<[3,4,2],f32>) -> ! return %0 : !torch.vtensor<[3,4,2],f32> } -// CHECK-LABEL: func.func @torch.aten.broadcast_to_strict$fold( -// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?],f32>, {{.*}}) -> !torch.vtensor<[?],f32> -// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[?],f32> -func.func @torch.aten.broadcast_to_strict$fold(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?],f32> attributes {torch.assume_strict_symbolic_shapes} { - %list = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list - %0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[?],f32>, !torch.list -> !torch.vtensor<[?],f32> - return %0 : !torch.vtensor<[?],f32> +// ----- + +// CHECK-LABEL: func.func @torch.aten.broadcast_to$fold_splat +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<3.000000e+00> : tensor<3x4x2xf32>) : !torch.vtensor<[3,4,2],f32> +// CHECK: return %[[CST]] +func.func @torch.aten.broadcast_to$fold_splat() -> !torch.vtensor<[3,4,2],f32> { + %tensor = torch.vtensor.literal(dense<3.0> : tensor<1x4x1xf32>) : !torch.vtensor<[1,4,1],f32> + %int3 = torch.constant.int 3 + %int4 = torch.constant.int 4 + %int2 = torch.constant.int 2 + %list = torch.prim.ListConstruct %int3, %int4, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %0 = torch.aten.broadcast_to %tensor, %list : !torch.vtensor<[1,4,1],f32>, !torch.list -> !torch.vtensor<[3,4,2],f32> + return %0 : !torch.vtensor<[3,4,2],f32> } +// ----- + // CHECK-LABEL: @torch.aten.slice.tensor$fold_full_domain_slice // CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4],f32> // CHECK: return %[[ARG0]] : !torch.vtensor<[4],f32> @@ -2024,10 +2079,64 @@ func.func @torch.aten.slice.tensor$no_fold_step(%arg0: !torch.vtensor<[?],f32>, return %0 : !torch.vtensor<[?],f32> } +// ----- +// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_1() -> (!torch.vtensor<[1,1],si64>, !torch.vtensor<[1,1],si64>) { +// CHECK-NOT: torch.aten.slice.Tensor +// CHECK: %[[RET_0:.*]] = torch.vtensor.literal(dense<50> : tensor<1x1xsi64>) : !torch.vtensor<[1,1],si64> +// CHECK-NOT: torch.aten.slice.Tensor +// CHECK: %[[RET_1:.*]] = torch.vtensor.literal(dense<70> : tensor<1x1xsi64>) : !torch.vtensor<[1,1],si64> +// CHECK-NOT: torch.aten.slice.Tensor +// CHECK: return %[[RET_0]], %[[RET_1]] +func.func @torch.aten.slice.tensor$fold_dim_1() -> (!torch.vtensor<[1, 1],si64>, !torch.vtensor<[1, 1],si64>) { + %tensor = torch.vtensor.literal(dense<[[10,20,30,40,50,60,70,80,90,100]]> : tensor<1x10xsi64>) : !torch.vtensor<[1, 10],si64> + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int4 = torch.constant.int 4 + %int5 = torch.constant.int 5 + %int6 = torch.constant.int 6 + %int7 = torch.constant.int 7 + %dim = torch.constant.int 1 + %0 = torch.aten.slice.Tensor %tensor, %dim, %int4, %int5, %int1 : !torch.vtensor<[1, 10], si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], si64> + %1 = torch.aten.slice.Tensor %tensor, %dim, %int6, %int7, %int1 : !torch.vtensor<[1, 10], si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], si64> + return %0, %1 : !torch.vtensor<[1,1],si64>, !torch.vtensor<[1,1],si64> +} + + +// ----- +// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_small() -> !torch.vtensor<[2],si32> { +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<[3, 5]> : tensor<2xsi32>) : !torch.vtensor<[2],si32> +// CHECK: return %[[CST]] +func.func @torch.aten.slice.tensor$fold_small() -> (!torch.vtensor<[2],si32>) { + %tensor = torch.vtensor.literal(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xsi32>) : !torch.vtensor<[10],si32> + %dim = torch.constant.int 0 + %int3 = torch.constant.int 3 + %int2 = torch.constant.int 2 + %int7 = torch.constant.int 7 + %0 = torch.aten.slice.Tensor %tensor, %dim, %int3, %int7, %int2 : !torch.vtensor<[10], si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2], si32> + return %0 : !torch.vtensor<[2],si32> +} + +// ----- + +func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1],f32>) { + %tensor = torch.vtensor.literal(dense<[[2.0],[4.0],[8.0],[16.0],[32.0],[64.0],[128.0],[256.0],[512.0],[1024.0]]> : tensor<10x1xf32>) : !torch.vtensor<[10, 1],f32> + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %intn7 = torch.constant.int -7 + %int4 = torch.constant.int 4 + %int5 = torch.constant.int 5 + %int6 = torch.constant.int 6 + %dim = torch.constant.int 0 + %0 = torch.aten.slice.Tensor %tensor, %dim, %intn7, %int4, %int1 : !torch.vtensor<[10, 1], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], f32> + %1 = torch.aten.slice.Tensor %tensor, %dim, %int5, %int6, %int1 : !torch.vtensor<[10, 1], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], f32> + return %0, %1 : !torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1], f32> +} + +// ----- + // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { -// CHECK: %int-1 = torch.constant.int -1 -// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[VAL_0]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-1> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 @@ -2037,11 +2146,8 @@ func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[] } // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { -// CHECK: %int-1 = torch.constant.int -1 -// CHECK: %int1 = torch.constant.int 1 -// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[VAL_1:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[VAL_1]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-1> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { %int1 = torch.constant.int 1 %int2 = torch.constant.int 2 @@ -2053,7 +2159,6 @@ func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtenso // CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number { // CHECK: %int1 = torch.constant.int 1 -// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64> // CHECK: %[[VAL_1:.*]] = torch.derefine %int1 : !torch.int to !torch.number // CHECK: return %[[VAL_1]] : !torch.number func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number { @@ -2073,6 +2178,52 @@ func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number return %1 : !torch.number } +// ----- + +// CHECK-LABEL: func.func @torch.aten.FloatImplicit$canonicalize_numtotensor_0d() -> !torch.float { +// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 +// CHECK: return %[[FLOAT1]] : !torch.float +func.func @torch.aten.FloatImplicit$canonicalize_numtotensor_0d() -> !torch.float { + %float1 = torch.constant.float 1.0 + %0 = torch.prim.NumToTensor.Scalar %float1 : !torch.float -> !torch.vtensor<[],f64> + %1 = torch.aten.FloatImplicit %0 : !torch.vtensor<[],f64> -> !torch.float + return %1 : !torch.float +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.FloatImplicit$canonicalize_literal_0d() -> !torch.float { +// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 +// CHECK: return %[[FLOAT1]] : !torch.float +func.func @torch.aten.FloatImplicit$canonicalize_literal_0d() -> !torch.float { + %0 = torch.vtensor.literal(dense<1.0> : tensor) : !torch.vtensor<[],f64> + %1 = torch.aten.FloatImplicit %0 : !torch.vtensor<[],f64> -> !torch.float + return %1 : !torch.float +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.IntImplicit$canonicalize_numtotensor_0d() -> !torch.int { +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: return %[[INT1]] : !torch.int +func.func @torch.aten.IntImplicit$canonicalize_numtotensor_0d() -> !torch.int { + %int1 = torch.constant.int 1 + %0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64> + %1 = torch.aten.IntImplicit %0 : !torch.vtensor<[],si64> -> !torch.int + return %1 : !torch.int +} + +// CHECK-LABEL: func.func @torch.aten.IntImplicit$canonicalize_literal_0d() -> !torch.int { +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: return %[[INT1]] : !torch.int +func.func @torch.aten.IntImplicit$canonicalize_literal_0d() -> !torch.int { + %0 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %1 = torch.aten.IntImplicit %0 : !torch.vtensor<[],si64> -> !torch.int + return %1 : !torch.int +} + +// ----- + // CHECK-LABEL: func.func @torch.prims.view_of$fold( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> { // CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[3,4,2],f32> @@ -2129,7 +2280,7 @@ func.func @torch.aten.floor$canonicalize(%arg0: !torch.vtensor<[?,?],si64>) -> ! } // CHECK-LABEL: func.func @torch.aten.numel$canonicalize -// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4],f32> +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4],f32> // CHECK-NEXT: %int12 = torch.constant.int 12 // CHECK-NEXT: return %int12 : !torch.int func.func @torch.aten.numel$canonicalize(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.int { @@ -2146,3 +2297,539 @@ func.func @torch.aten.masked_fill.Tensor$canonicalize(%arg0: !torch.vtensor<[?,? %1 = torch.aten.masked_fill.Tensor %arg0, %arg1, %0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32> return %1 : !torch.vtensor<[?,?],f32> } + +// CHECK-LABEL: func.func @torch.aten.detach$canonicalize +// CHECK-NEXT: torch.aten.detach +func.func @torch.aten.detach$canonicalize(%arg0: !torch.tensor<[1],f32>) -> !torch.tensor { + %1 = torch.aten.detach %arg0 : !torch.tensor<[1],f32> -> !torch.tensor + return %1 : !torch.tensor +} + +// CHECK-LABEL: func.func @torch.aten.index_select$noop( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,2,3],si64> +// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[1,2,3],si64> +func.func @torch.aten.index_select$noop(%arg0 : !torch.vtensor<[1,2,3],si64>, %arg1 : !torch.int, %arg2 : !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,2,3],si64> { + %0 = torch.aten.index_select %arg0, %arg1, %arg2 : !torch.vtensor<[1,2,3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1,2,3],si64> + return %0 : !torch.vtensor<[1,2,3],si64> +} + +// CHECK-LABEL: func.func @torch.aten.index_select$const_si_si( +// CHECK-NEXT: %[[RES:.*]] = torch.vtensor.literal(dense<60> : tensor<1xsi64>) : !torch.vtensor<[1],si64> +// CHECK-NEXT: return %[[RES]] : !torch.vtensor<[1],si64> +func.func @torch.aten.index_select$const_si_si() -> !torch.vtensor<[1],si64> { + %tensor = torch.vtensor.literal(dense<[10,20,30,40,50,60,70,80,90,100]> : tensor<10xsi64>) : !torch.vtensor<[10],si64> + %dim = torch.constant.int 0 + %index = torch.vtensor.literal(dense<5> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + return %0 : !torch.vtensor<[1],si64> +} + +// CHECK-LABEL: func.func @torch.aten.index_select$const_si_ui( +// CHECK-NEXT: %[[RES:.*]] = torch.vtensor.literal(dense<60> : tensor<1xsi64>) : !torch.vtensor<[1],si64> +// CHECK-NEXT: return %[[RES]] : !torch.vtensor<[1],si64> +func.func @torch.aten.index_select$const_si_ui() -> !torch.vtensor<[1],si64> { + %tensor = torch.vtensor.literal(dense<[10,20,30,40,50,60,70,80,90,100]> : tensor<10xsi64>) : !torch.vtensor<[10],si64> + %dim = torch.constant.int 0 + %index = torch.vtensor.literal(dense<5> : tensor<1xui64>) : !torch.vtensor<[1],ui64> + %0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],si64>, !torch.int, !torch.vtensor<[1],ui64> -> !torch.vtensor<[1],si64> + return %0 : !torch.vtensor<[1],si64> +} + +// CHECK-LABEL: func.func @torch.aten.index_select$const_f32_ui( +// CHECK-NEXT: %[[RES:.*]] = torch.vtensor.literal(dense<6.6{{.*}}> : tensor<1xf32>) : !torch.vtensor<[1],f32> +// CHECK-NEXT: return %[[RES]] : !torch.vtensor<[1],f32> +func.func @torch.aten.index_select$const_f32_ui() -> !torch.vtensor<[1],f32> { + %tensor = torch.vtensor.literal(dense<[1.1,2.2,3.3,4.4,5.5,6.6,7.7,8.8,9.9,10.0]> : tensor<10xf32>) : !torch.vtensor<[10],f32> + %dim = torch.constant.int 0 + %index = torch.vtensor.literal(dense<5> : tensor<1xui64>) : !torch.vtensor<[1],ui64> + %0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],f32>, !torch.int, !torch.vtensor<[1],ui64> -> !torch.vtensor<[1],f32> + return %0 : !torch.vtensor<[1],f32> +} + +// CHECK-LABEL: func.func @torch.aten.index_select$const_f32_si_neg( +// CHECK-NEXT: %[[RES:.*]] = torch.vtensor.literal(dense<7.{{.*}}> : tensor<1xf32>) : !torch.vtensor<[1],f32> +// CHECK-NEXT: return %[[RES]] : !torch.vtensor<[1],f32> +func.func @torch.aten.index_select$const_f32_si_neg() -> !torch.vtensor<[1],f32> { + %tensor = torch.vtensor.literal(dense<[1.1,2.2,3.3,4.4,5.5,6.6,7.7,8.8,9.9,10.0]> : tensor<10xf32>) : !torch.vtensor<[10],f32> + %dim = torch.constant.int -1 + %index = torch.vtensor.literal(dense<-4> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],f32>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],f32> + return %0 : !torch.vtensor<[1],f32> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_true_attr +func.func @fold_aten_where_true_attr() -> !torch.vtensor<[4],si64> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + // CHECK: return %[[RET]] + %bool = torch.vtensor.literal(dense<1> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %rhs = torch.vtensor.literal(dense<11> : tensor) : !torch.vtensor<[],si64> + %where = torch.aten.where.self %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.vtensor<[4],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[4],si64> + return %where : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_prim_numtotensor_scalar +func.func @fold_prim_numtotensor_scalar() -> !torch.vtensor<[1],si64> { + %int42 = torch.constant.int 42 + // CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<42> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + // CHECK: return %[[TENSOR]] + %0 = torch.prim.NumToTensor.Scalar %int42 : !torch.int -> !torch.vtensor<[1],si64> + return %0 : !torch.vtensor<[1],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_false_attr +func.func @fold_aten_where_false_attr() -> !torch.vtensor<[4],si64> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + // CHECK: return %[[RET]] + %bool = torch.vtensor.literal(dense<0> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %rhs = torch.vtensor.literal(dense<11> : tensor) : !torch.vtensor<[],si64> + %where = torch.aten.where.self %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.vtensor<[4],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[4],si64> + return %where : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_true_value +func.func @fold_aten_where_true_value(%arg0 : !torch.vtensor<[4],si64>, %arg1 : !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> { + // CHECK: return %arg0 + %bool = torch.vtensor.literal(dense<1> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %where = torch.aten.where.self %bool, %arg0, %arg1 : !torch.vtensor<[4],i1>, !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64> + return %where : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_false_value +func.func @fold_aten_where_false_value(%arg0 : !torch.vtensor<[4],si64>, %arg1 : !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> { + // CHECK: return %arg1 + %bool = torch.vtensor.literal(dense<0> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %where = torch.aten.where.self %bool, %arg0, %arg1 : !torch.vtensor<[4],i1>, !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64> + return %where : !torch.vtensor<[4],si64> +} + + +// ----- + +// CHECK-LABEL: @fold_aten_where_true_value_nofold +func.func @fold_aten_where_true_value_nofold(%arg0 : !torch.vtensor<[],si64>, %arg1 : !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> { + // CHECK: torch.aten.where.self + %bool = torch.vtensor.literal(dense<1> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %where = torch.aten.where.self %bool, %arg0, %arg1 : !torch.vtensor<[4],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64> + return %where : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_true_scalar_int +func.func @fold_aten_where_true_scalar_int() -> !torch.vtensor<[4],si64> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + // CHECK: return %[[RET]] + %bool = torch.vtensor.literal(dense<1> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.constant.int 7 + %rhs = torch.constant.int 11 + %where = torch.aten.where.Scalar %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.int, !torch.int -> !torch.vtensor<[4],si64> + return %where : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_false_scalar_int +func.func @fold_aten_where_false_scalar_int() -> !torch.vtensor<[4],ui8> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<11> : tensor<4xui8>) : !torch.vtensor<[4],ui8> + // CHECK: return %[[RET]] + %bool = torch.vtensor.literal(dense<0> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.constant.int 7 + %rhs = torch.constant.int 11 + %where = torch.aten.where.Scalar %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.int, !torch.int -> !torch.vtensor<[4],ui8> + return %where : !torch.vtensor<[4],ui8> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_false_scalar_fp +func.func @fold_aten_where_false_scalar_fp() -> !torch.vtensor<[4],f32> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<1.100000e+01> : tensor<4xf32>) : !torch.vtensor<[4],f32> + // CHECK: return %[[RET]] + %bool = torch.vtensor.literal(dense<0> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.constant.float 7.0 + %rhs = torch.constant.float 11.0 + %where = torch.aten.where.Scalar %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.float, !torch.float -> !torch.vtensor<[4],f32> + return %where : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_true_sother_int +func.func @fold_aten_where_true_sother_int() -> !torch.vtensor<[4],si64> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + // CHECK: %[[RET]] + %bool = torch.vtensor.literal(dense<1> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %rhs = torch.constant.int 11 + %where = torch.aten.where.ScalarOther %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + return %where : !torch.vtensor<[4],si64> +} + + +// ----- + +// CHECK-LABEL: @fold_aten_where_false_sother_int +func.func @fold_aten_where_false_sother_int() -> !torch.vtensor<[4],ui8> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<11> : tensor<4xui8>) : !torch.vtensor<[4],ui8> + // CHECK: return %[[RET]] + %bool = torch.vtensor.literal(dense<0> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.vtensor.literal(dense<7> : tensor) : !torch.vtensor<[],ui8> + %rhs = torch.constant.int 11 + %where = torch.aten.where.ScalarOther %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.vtensor<[],ui8>, !torch.int -> !torch.vtensor<[4],ui8> + return %where : !torch.vtensor<[4],ui8> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_false_sother_fp +func.func @fold_aten_where_false_sother_fp() -> !torch.vtensor<[4],f32> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<1.100000e+01> : tensor<4xf32>) : !torch.vtensor<[4],f32> + // CHECK: %[[RET]] + %bool = torch.vtensor.literal(dense<0> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %rhs = torch.constant.float 11.0 + %where = torch.aten.where.ScalarOther %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32> + return %where : !torch.vtensor<[4],f32> +} + + +// ----- + +// CHECK-LABEL: @fold_aten_where_true_sself_int +func.func @fold_aten_where_true_sself_int() -> !torch.vtensor<[4],si64> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + // CHECK: %[[RET]] + %bool = torch.vtensor.literal(dense<1> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.constant.int 7 + %rhs = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %where = torch.aten.where.ScalarSelf %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.int, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64> + return %where : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_false_sself_int +func.func @fold_aten_where_false_sself_int() -> !torch.vtensor<[4],ui8> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<11> : tensor<4xui8>) : !torch.vtensor<[4],ui8> + // CHECK: return %[[RET]] + %bool = torch.vtensor.literal(dense<0> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.constant.int 7 + %rhs = torch.vtensor.literal(dense<11> : tensor) : !torch.vtensor<[],ui8> + %where = torch.aten.where.ScalarSelf %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.int, !torch.vtensor<[],ui8> -> !torch.vtensor<[4],ui8> + return %where : !torch.vtensor<[4],ui8> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_false_sself_fp +func.func @fold_aten_where_false_sself_fp() -> !torch.vtensor<[4],f32> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<1.100000e+01> : tensor<4xf32>) : !torch.vtensor<[4],f32> + // CHECK: %[[RET]] + %bool = torch.vtensor.literal(dense<0> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.constant.float 7.0 + %rhs = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %where = torch.aten.where.ScalarSelf %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.float, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],f32> + return %where : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @aten_select_int_fold_splat +func.func @aten_select_int_fold_splat(%arg0 : !torch.int, %arg1 : !torch.int) -> !torch.vtensor<[1],si64> { + %splat = torch.vtensor.literal(dense<4> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %select = torch.aten.select.int %splat, %arg0, %arg1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<4> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + // CHECK: return %[[RET]] + return %select : !torch.vtensor<[1],si64> +} + +// ----- + +// CHECK-LABEL: @aten_select_int_fold_1D +func.func @aten_select_int_fold_1D() -> !torch.vtensor<[1],si64> { + %index = torch.constant.int 1 + %dim = torch.constant.int 0 + %splat = torch.vtensor.literal(dense<[5,6,7,8]> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %select = torch.aten.select.int %splat, %dim, %index : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<6> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + // CHECK: return %[[RET]] + return %select : !torch.vtensor<[1],si64> +} + +// ----- + +// CHECK-LABEL: @aten_select_int_fold_3D +func.func @aten_select_int_fold_3D() -> !torch.vtensor<[1, 1, 1],si64> { + %index = torch.constant.int 2 + %dim = torch.constant.int 2 + %splat = torch.vtensor.literal(dense<[[[5,6,7,8]]]> : tensor<1x1x4xsi64>) : !torch.vtensor<[1,1,4],si64> + %select = torch.aten.select.int %splat, %dim, %index : !torch.vtensor<[1,1,4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,1,1],si64> + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<7> : tensor<1x1x1xsi64>) : !torch.vtensor<[1,1,1],si64> + // CHECK: return %[[RET]] + return %select : !torch.vtensor<[1,1,1],si64> +} + +// ----- + + +// CHECK-LABEL: @aten_eq_tensor_args +func.func @aten_eq_tensor_args(%arg0 : !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],i1> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[RET]] + %0 = torch.aten.eq.Tensor %arg0, %arg0 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_eq_tensor_splats_int_false +func.func @aten_eq_tensor_splats_int_false() -> !torch.vtensor<[4],i1> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[RET]] + %lhs = torch.vtensor.literal(dense<4> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %rhs = torch.vtensor.literal(dense<5> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %0 = torch.aten.eq.Tensor %lhs, %rhs : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_eq_tensor_splats_int_true +func.func @aten_eq_tensor_splats_int_true() -> !torch.vtensor<[4],i1> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[RET]] + %lhs = torch.vtensor.literal(dense<5> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %rhs = torch.vtensor.literal(dense<5> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %0 = torch.aten.eq.Tensor %lhs, %rhs : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_eq_tensor_splats_fp_false +func.func @aten_eq_tensor_splats_fp_false() -> !torch.vtensor<[4],i1> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[RET]] + %lhs = torch.vtensor.literal(dense<4.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %rhs = torch.vtensor.literal(dense<5.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.eq.Tensor %lhs, %rhs : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_eq_tensor_splats_fp_true +func.func @aten_eq_tensor_splats_fp_true() -> !torch.vtensor<[4],i1> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[RET]] + %lhs = torch.vtensor.literal(dense<5.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %rhs = torch.vtensor.literal(dense<5.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.eq.Tensor %lhs, %rhs : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_eq_tensor_splat_dense_fp +func.func @aten_eq_tensor_splat_dense_fp() -> !torch.vtensor<[4],i1> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<[false, true, false, true]> : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[RET]] + %lhs = torch.vtensor.literal(dense<5.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %rhs = torch.vtensor.literal(dense<[4.0, 5.0, 6.0, 5.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.eq.Tensor %lhs, %rhs : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_eq_tensor_dense_fp +func.func @aten_eq_tensor_dense_fp() -> !torch.vtensor<[4],i1> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<[true, false, true, false]> : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[RET]] + %lhs = torch.vtensor.literal(dense<[4.0, 5.5, 6.0, 6.4]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %rhs = torch.vtensor.literal(dense<[4.0, 5.0, 6.0, 5.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.eq.Tensor %lhs, %rhs : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_eq_tensor_splat_dense_int +func.func @aten_eq_tensor_splat_dense_int() -> !torch.vtensor<[4],i1> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<[false, true, false, true]> : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[RET]] + %lhs = torch.vtensor.literal(dense<5> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %rhs = torch.vtensor.literal(dense<[4, 5, 6, 5]> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %0 = torch.aten.eq.Tensor %lhs, %rhs : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_eq_tensor_dense_int +func.func @aten_eq_tensor_dense_int() -> !torch.vtensor<[4],i1> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<[true, true, true, false]> : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[RET]] + %lhs = torch.vtensor.literal(dense<[4, 5, 6, 6]> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %rhs = torch.vtensor.literal(dense<[4, 5, 6, 5]> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %0 = torch.aten.eq.Tensor %lhs, %rhs : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_shape_to_tensor +func.func @aten_shape_to_tensor(%arg0 : !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[3],si32> { + // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<[4, 5, 6]> : tensor<3xsi32>) : !torch.vtensor<[3],si32> + %0 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[4,5,6],f32> -> !torch.vtensor<[3],si32> + // CHECK: return %[[CST]] + return %0 : !torch.vtensor<[3],si32> +} + +// ----- + +// CHECK-LABEL: @aten_cat_zero +func.func @aten_cat_zero(%arg0 : !torch.vtensor<[4,5,6],f32>, %arg1 : !torch.vtensor<[4,0,6],f32>) -> !torch.vtensor<[4,5,6],f32> { + // CHECK: return %arg0 : !torch.vtensor<[4,5,6],f32> + %list = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[4,5,6],f32>, !torch.vtensor<[4,0,6],f32>) -> !torch.list + %dim = torch.constant.int -2 + %0 = torch.aten.cat %list, %dim : !torch.list, !torch.int -> !torch.vtensor<[4,5,6],f32> + return %0 : !torch.vtensor<[4,5,6],f32> +} + +// ----- + +// CHECK-LABEL: @aten_tensor_scalar_lt +func.func @aten_tensor_scalar_lt() -> (!torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>) { + // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[CST]], %[[CST]] : !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1> + %intTensor = torch.vtensor.literal(dense<1> : tensor<4xsi8>) : !torch.vtensor<[4],si8> + %fpTensor = torch.vtensor.literal(dense<1.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %intScalar = torch.constant.int 2 + %fpScalar = torch.constant.float 2.0 + %intBool = torch.aten.lt.Scalar %intTensor, %intScalar : !torch.vtensor<[4],si8>, !torch.int -> !torch.vtensor<[4],i1> + %fpBool = torch.aten.lt.Scalar %fpTensor, %fpScalar : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],i1> + return %intBool, %fpBool : !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1> +} + + +// ----- + +// CHECK-LABEL: @aten_tensor_tensor_lt +func.func @aten_tensor_tensor_lt() -> (!torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>) { + // CHECK: %[[UNSIGN:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: %[[SIGNED:.+]] = torch.vtensor.literal(dense<[true, false, false, false]> : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[UNSIGN]], %[[SIGNED]], %[[SIGNED]] + %intTensor = torch.vtensor.literal(dense<[127, -128, -127, -126]> : tensor<4xsi8>) : !torch.vtensor<[4],si8> + %uintTensor = torch.vtensor.literal(dense<[127, 128, 129, 130]> : tensor<4xui8>) : !torch.vtensor<[4],ui8> + %fpTensor = torch.vtensor.literal(dense<[127.0, 128.0, 129.0, 130.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %intScalar = torch.constant.int 128 + %fpScalar = torch.constant.float 128.0 + %intBool = torch.aten.lt.Scalar %intTensor, %intScalar : !torch.vtensor<[4],si8>, !torch.int -> !torch.vtensor<[4],i1> + %uintBool = torch.aten.lt.Scalar %uintTensor, %intScalar : !torch.vtensor<[4],ui8>, !torch.int -> !torch.vtensor<[4],i1> + %fpBool = torch.aten.lt.Scalar %fpTensor, %fpScalar : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],i1> + return %intBool, %uintBool, %fpBool : !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_tensor_tensor_le +func.func @aten_tensor_tensor_le() -> (!torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>) { + // CHECK: %[[UNSIGN:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: %[[SIGNED:.+]] = torch.vtensor.literal(dense<[true, true, false, false]> : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[UNSIGN]], %[[SIGNED]], %[[SIGNED]] + %intTensor = torch.vtensor.literal(dense<[127, -128, -127, -126]> : tensor<4xsi8>) : !torch.vtensor<[4],si8> + %uintTensor = torch.vtensor.literal(dense<[127, 128, 129, 130]> : tensor<4xui8>) : !torch.vtensor<[4],ui8> + %fpTensor = torch.vtensor.literal(dense<[127.0, 128.0, 129.0, 130.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %intScalar = torch.constant.int 128 + %fpScalar = torch.constant.float 128.0 + %intBool = torch.aten.le.Scalar %intTensor, %intScalar : !torch.vtensor<[4],si8>, !torch.int -> !torch.vtensor<[4],i1> + %uintBool = torch.aten.le.Scalar %uintTensor, %intScalar : !torch.vtensor<[4],ui8>, !torch.int -> !torch.vtensor<[4],i1> + %fpBool = torch.aten.le.Scalar %fpTensor, %fpScalar : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],i1> + return %intBool, %uintBool, %fpBool : !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1> +} + + +// ----- + +// CHECK-LABEL: @aten_tensor_tensor_ge +func.func @aten_tensor_tensor_ge() -> (!torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>) { + // CHECK: %[[UNSIGN:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: %[[SIGNED:.+]] = torch.vtensor.literal(dense<[false, true, true, true]> : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[UNSIGN]], %[[SIGNED]], %[[SIGNED]] + %intTensor = torch.vtensor.literal(dense<[127, -128, -127, -126]> : tensor<4xsi8>) : !torch.vtensor<[4],si8> + %uintTensor = torch.vtensor.literal(dense<[127, 128, 129, 130]> : tensor<4xui8>) : !torch.vtensor<[4],ui8> + %fpTensor = torch.vtensor.literal(dense<[127.0, 128.0, 129.0, 130.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %intScalar = torch.constant.int 128 + %fpScalar = torch.constant.float 128.0 + %intBool = torch.aten.ge.Scalar %intTensor, %intScalar : !torch.vtensor<[4],si8>, !torch.int -> !torch.vtensor<[4],i1> + %uintBool = torch.aten.ge.Scalar %uintTensor, %intScalar : !torch.vtensor<[4],ui8>, !torch.int -> !torch.vtensor<[4],i1> + %fpBool = torch.aten.ge.Scalar %fpTensor, %fpScalar : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],i1> + return %intBool, %uintBool, %fpBool : !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_tensor_tensor_gt +func.func @aten_tensor_tensor_gt() -> (!torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>) { + // CHECK: %[[UNSIGN:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: %[[SIGNED:.+]] = torch.vtensor.literal(dense<[false, false, true, true]> : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[UNSIGN]], %[[SIGNED]], %[[SIGNED]] + %intTensor = torch.vtensor.literal(dense<[127, -128, -127, -126]> : tensor<4xsi8>) : !torch.vtensor<[4],si8> + %uintTensor = torch.vtensor.literal(dense<[127, 128, 129, 130]> : tensor<4xui8>) : !torch.vtensor<[4],ui8> + %fpTensor = torch.vtensor.literal(dense<[127.0, 128.0, 129.0, 130.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %intScalar = torch.constant.int 128 + %fpScalar = torch.constant.float 128.0 + %intBool = torch.aten.gt.Scalar %intTensor, %intScalar : !torch.vtensor<[4],si8>, !torch.int -> !torch.vtensor<[4],i1> + %uintBool = torch.aten.gt.Scalar %uintTensor, %intScalar : !torch.vtensor<[4],ui8>, !torch.int -> !torch.vtensor<[4],i1> + %fpBool = torch.aten.gt.Scalar %fpTensor, %fpScalar : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],i1> + return %intBool, %uintBool, %fpBool : !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_tensor_tensor_eq +func.func @aten_tensor_tensor_eq() -> (!torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>) { + // CHECK: %[[UNSIGN:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: %[[SIGNED:.+]] = torch.vtensor.literal(dense<[false, true, false, false]> : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[UNSIGN]], %[[SIGNED]], %[[SIGNED]] + %intTensor = torch.vtensor.literal(dense<[127, -128, -127, -126]> : tensor<4xsi8>) : !torch.vtensor<[4],si8> + %uintTensor = torch.vtensor.literal(dense<[127, 128, 129, 130]> : tensor<4xui8>) : !torch.vtensor<[4],ui8> + %fpTensor = torch.vtensor.literal(dense<[127.0, 128.0, 129.0, 130.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %intScalar = torch.constant.int 128 + %fpScalar = torch.constant.float 128.0 + %intBool = torch.aten.eq.Scalar %intTensor, %intScalar : !torch.vtensor<[4],si8>, !torch.int -> !torch.vtensor<[4],i1> + %uintBool = torch.aten.eq.Scalar %uintTensor, %intScalar : !torch.vtensor<[4],ui8>, !torch.int -> !torch.vtensor<[4],i1> + %fpBool = torch.aten.eq.Scalar %fpTensor, %fpScalar : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],i1> + return %intBool, %uintBool, %fpBool : !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_tensor_tensor_ne +func.func @aten_tensor_tensor_ne() -> (!torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>) { + // CHECK: %[[UNSIGN:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: %[[SIGNED:.+]] = torch.vtensor.literal(dense<[true, false, true, true]> : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[UNSIGN]], %[[SIGNED]], %[[SIGNED]] + %intTensor = torch.vtensor.literal(dense<[127, -128, -127, -126]> : tensor<4xsi8>) : !torch.vtensor<[4],si8> + %uintTensor = torch.vtensor.literal(dense<[127, 128, 129, 130]> : tensor<4xui8>) : !torch.vtensor<[4],ui8> + %fpTensor = torch.vtensor.literal(dense<[127.0, 128.0, 129.0, 130.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %intScalar = torch.constant.int 128 + %fpScalar = torch.constant.float 128.0 + %intBool = torch.aten.ne.Scalar %intTensor, %intScalar : !torch.vtensor<[4],si8>, !torch.int -> !torch.vtensor<[4],i1> + %uintBool = torch.aten.ne.Scalar %uintTensor, %intScalar : !torch.vtensor<[4],ui8>, !torch.int -> !torch.vtensor<[4],i1> + %fpBool = torch.aten.ne.Scalar %fpTensor, %fpScalar : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],i1> + return %intBool, %uintBool, %fpBool : !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1> +} diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index e5d5ca19d8a2..0e863ffdfe09 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -123,8 +123,8 @@ func.func @torch.aten.acos$float_type(%arg0: !torch.vtensor<[2, 2],f32>, %arg1: // CHECK-LABEL: func.func @torch.aten.type_as$basic( // CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor { -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false +// CHECK-DAG: %[[NONE:.*]] = torch.constant.none // CHECK: %[[DTYPE:.*]] = torch.prim.dtype %[[ARG_1]] : !torch.tensor -> !torch.int // CHECK: %[[VAR:.*]] = torch.aten.to.dtype %[[ARG_0]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor // CHECK: return %[[VAR]] : !torch.tensor diff --git a/test/Dialect/Torch/fuse-quantized-ops.mlir b/test/Dialect/Torch/fuse-quantized-ops.mlir new file mode 100644 index 000000000000..f98cb842f5d3 --- /dev/null +++ b/test/Dialect/Torch/fuse-quantized-ops.mlir @@ -0,0 +1,97 @@ +// RUN: torch-mlir-opt %s --split-input-file --torch-fuse-quantized-ops | FileCheck %s + +// CHECK-LABEL: @mm +func.func @mm(%arg0: !torch.vtensor<[4, 4],si8>, %arg1: !torch.vtensor<[4, 4],si8>) -> !torch.vtensor<[4, 4],f32> { + %scale = torch.constant.float 0.5 + %false = torch.constant.bool false + %zero = torch.constant.int 0 + %one = torch.constant.int 1 + %zp = torch.constant.int -128 + %6 = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %one : !torch.vtensor<[4, 4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4, 4],!torch.qint8> + %7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[4, 4],!torch.qint8> -> !torch.vtensor<[4, 4],f32> + %12 = torch.aten._make_per_tensor_quantized_tensor %arg1, %scale, %zero : !torch.vtensor<[4, 4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4, 4],!torch.qint8> + %13 = torch.aten.dequantize.tensor %12 : !torch.vtensor<[4, 4],!torch.qint8> -> !torch.vtensor<[4, 4],f32> + %16 = torch.aten.mm %7, %13 : !torch.vtensor<[4, 4],f32>, !torch.vtensor<[4, 4],f32> -> !torch.vtensor<[4, 4],f32> + + // CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[QUARTER:.+]] = torch.constant.float 2.500000e-01 + // CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01 + // CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF:.+]], %[[ONE]] : !torch.vtensor<[4,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4,4],!torch.qint8> + // CHECK-DAG: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF:.+]], %[[ZERO]] : !torch.vtensor<[4,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4,4],!torch.qint8> + // CHECK-DAG: %[[MM:.+]] = torch.aten.mm %[[QLHS]], %[[QRHS]] : !torch.vtensor<[4,4],!torch.qint8>, !torch.vtensor<[4,4],!torch.qint8> -> !torch.vtensor<[4,4],!torch.qint32> + // CHECK-DAG: %[[INT:.+]] = torch.aten.int_repr %[[MM]] : !torch.vtensor<[4,4],!torch.qint32> -> !torch.vtensor<[4,4],si32> + // CHECK-DAG: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[INT]], %[[QUARTER]], %[[ZERO]] : !torch.vtensor<[4,4],si32>, !torch.float, !torch.int -> !torch.vtensor<[4,4],!torch.qint32> + // CHECK: %[[OUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[4,4],!torch.qint32> -> !torch.vtensor<[4,4],f32> + return %16 : !torch.vtensor<[4, 4],f32> +} + +// ----- + +// CHECK-LABEL: @convolution_bias +func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> { + %scale = torch.constant.float 0.5 + %false = torch.constant.bool false + %zero = torch.constant.int 0 + %one = torch.constant.int 1 + %zp = torch.constant.int -128 + %6 = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %one : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8> + %7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],f32> + %12 = torch.aten._make_per_tensor_quantized_tensor %arg1, %scale, %zero : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8> + %13 = torch.aten.dequantize.tensor %12 : !torch.vtensor<[3,3,2,2],!torch.qint8> -> !torch.vtensor<[3,3,2,2],f32> + %14 = torch.prim.ListConstruct %one, %one : (!torch.int, !torch.int) -> !torch.list + %15 = torch.prim.ListConstruct %zero, %zero : (!torch.int, !torch.int) -> !torch.list + %16 = torch.aten.convolution %7, %13, %arg2, %14, %15, %14, %false, %15, %one : !torch.vtensor<[1,3,8,8],f32>, !torch.vtensor<[3,3,2,2],f32>, !torch.vtensor<[3],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,3,7,7],f32> + + // CHECK: %[[DTYPE:.+]] = torch.constant.int 14 + // CHECK: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01 + // CHECK: %[[HALF:.+]] = torch.constant.float 5.000000e-01 + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK: %[[ONE:.+]] = torch.constant.int 1 + // CHECK: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8> + // CHECK: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8> + // CHECK: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[QBIAS:.+]] = torch.aten.quantize_per_tensor %arg2, %[[SCALEO]], %[[ZERO]], %[[DTYPE]] : !torch.vtensor<[3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[3],!torch.qint32> + // CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QBIAS]] : !torch.vtensor<[3],!torch.qint32> -> !torch.vtensor<[3],si32> + // CHECK: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[INT]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.vtensor<[3],si32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,3,7,7],si32> + // CHECK: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32> + // CHECK: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32> + return %16 : !torch.vtensor<[1,3,7,7],f32> +} + + +// ----- + +// CHECK-LABEL: @convolution_nobias +func.func @convolution_nobias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>) -> !torch.vtensor<[1,3,7,7],f32> { + %scale = torch.constant.float 0.5 + %false = torch.constant.bool false + %zero = torch.constant.int 0 + %one = torch.constant.int 1 + %zp = torch.constant.int -128 + %none = torch.constant.none + %6 = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %one : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8> + %7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],f32> + %12 = torch.aten._make_per_tensor_quantized_tensor %arg1, %scale, %zero : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8> + %13 = torch.aten.dequantize.tensor %12 : !torch.vtensor<[3,3,2,2],!torch.qint8> -> !torch.vtensor<[3,3,2,2],f32> + %14 = torch.prim.ListConstruct %one, %one : (!torch.int, !torch.int) -> !torch.list + %15 = torch.prim.ListConstruct %zero, %zero : (!torch.int, !torch.int) -> !torch.list + %16 = torch.aten.convolution %7, %13, %none, %14, %15, %14, %false, %15, %one : !torch.vtensor<[1,3,8,8],f32>, !torch.vtensor<[3,3,2,2],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,3,7,7],f32> + + // CHECK-DAG: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01 + // CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01 + // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false + // CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8> + // CHECK-DAG: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8> + // CHECK-DAG: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[NONE]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,3,7,7],si32> + // CHECK-DAG: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32> + // CHECK: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32> + return %16 : !torch.vtensor<[1,3,7,7],f32> +} diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir index 067c1a9b67f4..63aa1e3755a9 100644 --- a/test/Dialect/Torch/invalid.mlir +++ b/test/Dialect/Torch/invalid.mlir @@ -362,3 +362,16 @@ func.func @torch.permute$invalid_index_in_permutation (%arg0: !torch.vtensor<[1, return %3 : !torch.vtensor<[1,2,3],f32> } +// ----- + +#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> + +// expected-error @+1 {{dimension-rank mismatch between encoding and tensor shape: 1 != 2}} +func.func @foo(%arg0: !torch.vtensor<[64,64],f32,#SV>) -> !torch.vtensor<[64,64],f32,#SV> { + return %arg0 : !torch.vtensor<[64,64],f32,#SV> +} + +// ----- + +// expected-error @+1 {{invalid sparsity encoding attribute}} +func.func private @tensor.sparse() -> !torch.vtensor<[64,64],f32,12345> diff --git a/test/Dialect/Torch/match-quantized-customs-ops.mlir b/test/Dialect/Torch/match-quantized-customs-ops.mlir new file mode 100644 index 000000000000..4196e688157f --- /dev/null +++ b/test/Dialect/Torch/match-quantized-customs-ops.mlir @@ -0,0 +1,42 @@ +// RUN: torch-mlir-opt --split-input-file --torch-match-quantized-custom-ops %s | FileCheck %s + +// CHECK-LABEL: func.func @quantize_per_tensor +func.func @quantize_per_tensor(%arg0: !torch.vtensor<[1,3,8,8],f32>) -> !torch.vtensor<[1,3,8,8],si8> { + %float = torch.constant.float 0.5 + %zp = torch.constant.int 17 + %min = torch.constant.int -128 + %max = torch.constant.int 127 + %dtype = torch.constant.int 1 + + // CHECK-DAG: %[[SCALE:.+]] = torch.constant.float 5.000000e-01 + // CHECK-DAG: %[[ZP:.+]] = torch.constant.int 17 + // CHECK-DAG: %[[MIN:.+]] = torch.constant.int -128 + // CHECK-DAG: %[[MAX:.+]] = torch.constant.int 127 + // CHECK-DAG: %[[DTYPE:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %arg0, %[[SCALE]], %[[ZP]], %[[DTYPE]] : !torch.vtensor<[1,3,8,8],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8> + // CHECK-DAG: %[[REPR:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],si8> + // CHECK: torch.aten.clamp %[[REPR]], %[[MIN]], %[[MAX]] + %0 = torch.operator "torch.quantized_decomposed.quantize_per_tensor"(%arg0, %float, %zp, %min, %max, %dtype) : (!torch.vtensor<[1,3,8,8],f32>, !torch.float, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,3,8,8],si8> + return %0 : !torch.vtensor<[1,3,8,8],si8> +} + +// ----- + +// CHECK-LABEL: func.func @dequantize_per_tensor +func.func @dequantize_per_tensor(%arg0: !torch.vtensor<[1,3,8,8],si8>) -> !torch.vtensor<[1,3,8,8],f32> { + %float = torch.constant.float 0.5 + %zp = torch.constant.int 17 + %min = torch.constant.int -128 + %max = torch.constant.int 127 + %dtype = torch.constant.int 1 + + // CHECK-DAG: %[[SCALE:.+]] = torch.constant.float 5.000000e-01 + // CHECK-DAG: %[[ZP:.+]] = torch.constant.int 17 + // CHECK-DAG: %[[MIN:.+]] = torch.constant.int -128 + // CHECK-DAG: %[[MAX:.+]] = torch.constant.int 127 + // CHECK-DAG: %[[CLAMP:.+]] = torch.aten.clamp %arg0, %[[MIN]], %[[MAX]] : !torch.vtensor<[1,3,8,8],si8>, !torch.int, !torch.int -> !torch.vtensor<[1,3,8,8],si8> + // CHECK-DAG: %[[QINT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CLAMP]], %[[SCALE]], %[[ZP]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8> + // CHECK: %[[DEQUANT:.+]] = torch.aten.dequantize.tensor %[[QINT]] : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],f32> + %13 = torch.operator "torch.quantized_decomposed.dequantize_per_tensor"(%arg0, %float, %zp, %min, %max, %dtype) : (!torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,3,8,8],f32> + return %13 : !torch.vtensor<[1,3,8,8],f32> +} diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir index 0feef563c3d5..54dc6bb058e5 100644 --- a/test/Dialect/Torch/ops.mlir +++ b/test/Dialect/Torch/ops.mlir @@ -1,5 +1,7 @@ // RUN: torch-mlir-opt %s | torch-mlir-opt | FileCheck %s +// CHECK: #[[$ENCODING:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> + // CHECK-LABEL: func.func @torch.operator( func.func @torch.operator(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !torch.tensor { // CHECK: torch.operator "ns.unqual.overload"(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tensor @@ -28,6 +30,10 @@ func.func private @tensor.some_sizes_known() -> !torch.tensor<[?,2,?,4],unk> // CHECK: @tensor.fully_determined() -> !torch.vtensor<[1,2,3,4],f32> func.func private @tensor.fully_determined() -> !torch.vtensor<[1,2,3,4],f32> +// CHECK: @tensor.sparse() -> !torch.vtensor<[64,64],f32,#[[$ENCODING]]> +#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +func.func private @tensor.sparse() -> !torch.vtensor<[64,64],f32,#CSR> + // CHECK: @tuple.empty() -> !torch.tuple<> func.func private @tuple.empty() -> !torch.tuple<> // CHECK: @tuple.one_element() -> !torch.tuple diff --git a/test/Dialect/Torch/reduce-op-variants.mlir b/test/Dialect/Torch/reduce-op-variants.mlir index 1122a7b3f844..94bec8aa2160 100644 --- a/test/Dialect/Torch/reduce-op-variants.mlir +++ b/test/Dialect/Torch/reduce-op-variants.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt -torch-reduce-op-variants %s | FileCheck %s +// RUN: torch-mlir-opt -torch-reduce-op-variants --split-input-file %s | FileCheck %s // CHECK-LABEL: func.func @convert_to_value_semantic_tensors( // CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> { @@ -11,6 +11,8 @@ func.func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !t return %0 : !torch.tensor<[],f32> } +// ----- + // CHECK-LABEL: func.func @convert_to_value_semantic_tensors_list( // CHECK-SAME: %[[VT0:.*]]: !torch.vtensor, %[[VT1:.*]]: !torch.vtensor, // CHECK-SAME: %[[VT2:.*]]: !torch.vtensor) -> !torch.tensor { @@ -40,6 +42,8 @@ func.func @convert_to_value_semantic_tensors_list(%vt0: !torch.vtensor, %vt1: !t return %ret : !torch.tensor } +// ----- + // CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional( // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor, %[[FLOAT_TENSOR:.*]]: !torch.tensor<[4],f32>, // CHECK-SAME: %[[TRAINING:.*]]: !torch.bool, %[[CUDNN_ENABLE:.*]]: !torch.bool, @@ -83,6 +87,8 @@ func.func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor, return %ret: !torch.tensor } +// ----- + // CHECK-LABEL: func.func @reduce_trailing_underscore_inplace_variant( // CHECK-SAME: %[[ARG0:.*]]: !torch.tensor<[2,2],f32>, // CHECK-SAME: %[[ARG1:.*]]: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) { @@ -106,6 +112,7 @@ func.func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2] %0 = torch.aten.add_.Tensor %arg0, %arg1, %c1 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>, !torch.int -> !torch.tensor<[2,2],f32> return %0, %arg0 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32> } +// ----- // CHECK-LABEL: func.func @torch.tensor.literal() -> !torch.tensor { // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<7xf32>) : !torch.vtensor<[7],f32> @@ -117,6 +124,8 @@ func.func @torch.tensor.literal() -> !torch.tensor { return %0 : !torch.tensor } +// ----- + // CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional_list( // CHECK-SAME: %[[SELF:.*]]: !torch.tensor<[5],f32>, // CHECK-SAME: %[[INDICES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor { @@ -134,6 +143,8 @@ func.func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor< return %ret : !torch.tensor } +// ----- + // CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional_list_nones_and_tensors( // CHECK-SAME: %[[SELF:.*]]: !torch.tensor<[5],f32>, // CHECK-SAME: %[[INDICES_0:.*]]: !torch.tensor<[2,3],si64>, @@ -155,6 +166,8 @@ func.func @convert_to_value_semantic_tensors_optional_list_nones_and_tensors(%se return %ret : !torch.tensor } +// ----- + // CHECK-LABEL: func.func @torch.aten.bernoulli_.float( // CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor { // CHECK: %[[GENERATOR:.*]] = torch.constant.none @@ -171,3 +184,22 @@ func.func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor { %ret = torch.aten.bernoulli_.float %t, %p, %generator : !torch.tensor, !torch.float, !torch.none -> !torch.tensor return %ret : !torch.tensor } + +// ----- + +// CHECK-LABEL: func.func @scaled_dot_product_flash_attention_for_cpu +// CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[1,1,5,5],f32>, %[[ARG1:.+]]: !torch.vtensor<[1,1,5,5],f32>, %[[ARG2:.+]]: !torch.vtensor<[1,1,5,5],f32> +// CHECK: %[[ZERO:.+]] = torch.constant.float 0.000000e+00 +// CHECK: %[[FALSE:.+]] = torch.constant.bool false +// CHECK: %[[NONE0:.+]] = torch.constant.none +// CHECK: %[[NONE1:.+]] = torch.constant.none +// CHECK: %[[ATTEN:.+]] = torch.aten.scaled_dot_product_attention %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[NONE0]], %[[ZERO]], %[[FALSE]], %[[NONE1]] +// CHECK: return %[[ATTEN]] +func.func @scaled_dot_product_flash_attention_for_cpu(%arg0: !torch.vtensor<[1,1,5,5],f32>, %arg1: !torch.vtensor<[1,1,5,5],f32>, %arg2: !torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,5,5],f32> { + %float0.000000e00 = torch.constant.float 0.000000e+00 + %false = torch.constant.bool false + %none = torch.constant.none + %none_0 = torch.constant.none + %0:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%arg0, %arg1, %arg2, %float0.000000e00, %false, %none, %none_0) : (!torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,5,5],f32>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,5],f32>) + return %0#0 : !torch.vtensor<[1,1,5,5],f32> +} diff --git a/test/Dialect/Torch/simplify-shape-calculations.mlir b/test/Dialect/Torch/simplify-shape-calculations.mlir index 10a65a527873..b7e7cf17ba0e 100644 --- a/test/Dialect/Torch/simplify-shape-calculations.mlir +++ b/test/Dialect/Torch/simplify-shape-calculations.mlir @@ -105,9 +105,9 @@ func.func @refine_shape_calculate_result$user_allows_type_refinement(%arg0: !tor // CHECK-LABEL: func.func @fully_unroll_prim_loop$unroll( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG1:.*]]: !torch.list) -> !torch.vtensor { -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1 +// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2 +// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[RESULT:.*]] = torch.shape.calculate { // CHECK: torch.shape.calculate.yield %[[ARG0]] : !torch.vtensor // CHECK: } shapes { @@ -316,7 +316,7 @@ func.func @abstractly_interpret_list_ops$mutation_in_child_region(%arg0: !torch. // CHECK: } else { // CHECK: torch.prim.If.yield %[[ARG1]] : !torch.list // CHECK: } - // .... and this one don't have the same object identity, but should! + // .... and this one don't have the same object identity, but should! // CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_9:.*]] = torch.prim.If %[[ARG2]] -> (!torch.list) { // CHECK: torch.prim.If.yield %[[VAL_8]] : !torch.list @@ -375,8 +375,8 @@ func.func @abstractly_interpret_list_ops$miscompile$list_identity(%arg0: !torch. // missing. // CHECK-LABEL: func.func @basic_integration( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],unk>) -> !torch.vtensor { -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0 +// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[RESULT:.*]] = torch.shape.calculate { // CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?],unk> -> !torch.vtensor<[?,?],unk> // CHECK: torch.shape.calculate.yield %[[TANH]] : !torch.vtensor<[?,?],unk> @@ -410,8 +410,8 @@ func.func @basic_integration(%arg0: !torch.vtensor<[?,?],unk>) -> !torch.vtensor // CHECK-LABEL: func.func @fold_prim_unchecked_cast_op( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor { -// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 -// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK-DAG: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK-DAG: %[[VAL_3:.*]] = torch.constant.int 1 // CHECK: %[[VAL_4:.*]] = torch.shape.calculate { // CHECK: %[[VAL_5:.*]] = torch.tensor_static_info_cast %[[VAL_0]] : !torch.vtensor to !torch.vtensor<[?,?],unk> // CHECK: torch.shape.calculate.yield %[[VAL_5]] : !torch.vtensor<[?,?],unk> diff --git a/test/Dialect/Torch/torch-nary-canonicalize.mlir b/test/Dialect/Torch/torch-nary-canonicalize.mlir new file mode 100644 index 000000000000..b0d22e35da9c --- /dev/null +++ b/test/Dialect/Torch/torch-nary-canonicalize.mlir @@ -0,0 +1,143 @@ +// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s + +// CHECK-LABEL: @fold_aten_add_splat_int +func.func @fold_aten_add_splat_int() -> !torch.vtensor<[4],si64> { + // CHECK: torch.vtensor.literal(dense<29> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %int2 = torch.constant.int 2 + %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + return %0 : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_add_splat_int_mismatch +func.func @fold_aten_add_splat_int_mismatch() -> !torch.vtensor<[4],si64> { + // CHECK: torch.vtensor.literal(dense<29> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi32>) : !torch.vtensor<[4],si32> + %int2 = torch.constant.int 2 + %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si32>, !torch.int -> !torch.vtensor<[4],si64> + return %0 : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_add_splat_float +func.func @fold_aten_add_splat_float() -> !torch.vtensor<[4],f32> { + // CHECK: torch.vtensor.literal(dense<2.900000e+01> : tensor<4xf32>) + %int2 = torch.constant.float 2.0 + %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @fold_aten_add_splat_float_mismatch +func.func @fold_aten_add_splat_float_mismatch() -> !torch.vtensor<[4],f32> { + // CHECK: torch.vtensor.literal(dense<2.900000e+01> : tensor<4xf32>) + %int2 = torch.constant.float 2.0 + %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf64>) : !torch.vtensor<[4],f64> + %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],f64>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @fold_aten_add_arr0_int +func.func @fold_aten_add_arr0_int() -> !torch.vtensor<[4],si64> { + // CHECK: torch.vtensor.literal(dense<[28, 29, 30, 31]> : tensor<4xsi64>) + %cst_7 = torch.vtensor.literal(dense<[6,7,8,9]> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %int2 = torch.constant.int 2 + %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + return %0 : !torch.vtensor<[4],si64> +} + + +// ----- + +// CHECK-LABEL: @fold_aten_add_arr1_int +func.func @fold_aten_add_arr1_int() -> !torch.vtensor<[4],si64> { + // CHECK: torch.vtensor.literal(dense<[27, 29, 31, 33]> : tensor<4xsi64>) + %int2 = torch.constant.int 2 + %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_11 = torch.vtensor.literal(dense<[10,11,12,13]> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + return %0 : !torch.vtensor<[4],si64> +} + + +// ----- + +// CHECK-LABEL: @fold_aten_add_arr0_float +func.func @fold_aten_add_arr0_float() -> !torch.vtensor<[4],f32> { + // CHECK: torch.vtensor.literal(dense<[2.800000e+01, 2.900000e+01, 3.000000e+01, 3.100000e+01]> : tensor<4xf32>) + %int2 = torch.constant.float 2.0 + %cst_7 = torch.vtensor.literal(dense<[6.0, 7.0, 8.0, 9.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @fold_aten_add_arr1_float +func.func @fold_aten_add_arr1_float() -> !torch.vtensor<[4],f32> { + // CHECK: torch.vtensor.literal(dense<[2.700000e+01, 2.900000e+01, 3.100000e+01, 3.300000e+01]> : tensor<4xf32>) + %fp_2 = torch.constant.float 2.0 + %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_11 = torch.vtensor.literal(dense<[10.0,11.0,12.0,13.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.add.Tensor %cst_7, %cst_11, %fp_2 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @fold_aten_sub_splat_int +func.func @fold_aten_sub_splat_int() -> !torch.vtensor<[4],si64> { + // CHECK: torch.vtensor.literal(dense<-15> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %int_2 = torch.constant.int 2 + %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %0 = torch.aten.sub.Tensor %cst_7, %cst_11, %int_2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + return %0 : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_sub_splat_float +func.func @fold_aten_sub_splat_float() -> !torch.vtensor<[4],f32> { + // CHECK: torch.vtensor.literal(dense<-1.500000e+01> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %fp_2 = torch.constant.float 2.0 + %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.sub.Tensor %cst_7, %cst_11, %fp_2 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @fold_aten_mul_splat_int +func.func @fold_aten_mul_splat_int() -> !torch.vtensor<[4],si64> { + // CHECK: torch.vtensor.literal(dense<77> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %0 = torch.aten.mul.Tensor %cst_7, %cst_11: !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64> + return %0 : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_mul_splat_float +func.func @fold_aten_mul_splat_float() -> !torch.vtensor<[4],f32> { + // CHECK: torch.vtensor.literal(dense<7.700000e+01> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.mul.Tensor %cst_7, %cst_11 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} diff --git a/test/Dialect/TorchConversion/convert-custom-quant-op.mlir b/test/Dialect/TorchConversion/convert-custom-quant-op.mlir index 4f72f24e8868..7aca3551cfc2 100644 --- a/test/Dialect/TorchConversion/convert-custom-quant-op.mlir +++ b/test/Dialect/TorchConversion/convert-custom-quant-op.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt %s -torch-convert-custom-quant-op -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt %s '-pass-pipeline=builtin.module(func.func(torch-convert-custom-quant-op))' -split-input-file -verify-diagnostics | FileCheck %s // CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)> diff --git a/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir b/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir index a16da0932640..57077a723ada 100644 --- a/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir +++ b/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt %s -torch-finalizing-backend-type-conversion -split-input-file -verify-diagnostics -allow-unregistered-dialect | FileCheck %s +// RUN: torch-mlir-opt %s '-pass-pipeline=builtin.module(func.func(torch-finalizing-backend-type-conversion))' -split-input-file -verify-diagnostics -allow-unregistered-dialect | FileCheck %s // This test is largely copied from `finalizing-bufferize` upstream, as it // covers the same scope. @@ -54,6 +54,20 @@ func.func @eliminate_materializations$torch.Generator(%arg0: i64) -> i64 { // ----- +// CHECK-LABEL: func.func @eliminate_attributes() +// CHECK-NOT: attributes +// CHECK-NOT: torch.onnx_meta +func.func @eliminate_attributes() attributes { + torch.onnx_meta.ir_version = 8 : si64, + torch.onnx_meta.opset_version = 17 : si64, + torch.onnx_meta.producer_name = "pytorch", + torch.onnx_meta.producer_version = "2.1.0" +} { + return +} + +// ----- + func.func @unable_to_convert_lone_buffer_cast() -> tensor { // expected-error @+1 {{failed to legalize operation 'test.source'}} %0 = "test.source"() : () -> !torch.vtensor<[],f32> diff --git a/test/Dialect/TorchConversion/unpack-quant-tensor.mlir b/test/Dialect/TorchConversion/unpack-quant-tensor.mlir index 0ca64ae09397..8fa1a775b66d 100644 --- a/test/Dialect/TorchConversion/unpack-quant-tensor.mlir +++ b/test/Dialect/TorchConversion/unpack-quant-tensor.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt %s -torch-unpack-quant-tensor -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt %s '-pass-pipeline=builtin.module(func.func(torch-unpack-quant-tensor))' -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func @forward func.func @forward(%arg0: !torch.vtensor<[1,1,8],f16>) -> !torch.vtensor<[1,1,8],f16> { diff --git a/test/lit.cfg.py b/test/lit.cfg.py index a9753bf22719..4608dfb6c892 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -24,7 +24,7 @@ config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) # suffixes: A list of file extensions to treat as test files. -config.suffixes = ['.mlir', '.py'] +config.suffixes = ['.mlir', '.py', '.runlit'] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) diff --git a/test/python/compile.py b/test/python/compile.py new file mode 100644 index 000000000000..678a4137acf6 --- /dev/null +++ b/test/python/compile.py @@ -0,0 +1,33 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import gc +import sys +import torch +from torch_mlir import torchscript + + +def run_test(f): + print("TEST:", f.__name__, file=sys.stderr) + f() + gc.collect() + + +class TinyModel(torch.nn.Module): + def __init__(self): + super(TinyModel, self).__init__() + + self.linear = torch.nn.Linear(20, 30) + + def forward(self, x): + x = self.linear(x) + return x + + +@run_test +def test_enable_ir_printing(): + torchscript.compile(TinyModel(), + torch.ones(1, 3, 20, 20), + output_type="linalg-on-tensors", + enable_ir_printing=True) +# CHECK: // -----// IR Dump Before Canonicalizer (canonicalize) +# CHECK-NEXT: module attributes {torch.debug_module_name = "TinyModel"} { diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py new file mode 100644 index 000000000000..a51032273999 --- /dev/null +++ b/test/python/fx_importer/basic_test.py @@ -0,0 +1,81 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# Requires torch>=2.3.0.dev20240307 +# UNSUPPORTED: true +# RUN: %PYTHON %s | FileCheck %s + +from typing import Optional + +import torch +import torch.export +import torch.nn as nn + +from torch_mlir import fx + + +def run(f): + print(f"{f.__name__}") + print("-" * len(f.__name__)) + f() + print() + + +@run +# CHECK-LABEL: test_import_frozen_exported_program +# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> +# CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32> +# CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32> +# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> +# CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]] +# CHECK-DAG: %[[mul_a:.+]] = torch.aten.mul.Tensor %[[tanh]], %[[a]] +# CHECK-DAG: %[[mul_b:.+]] = torch.aten.mul.Tensor %[[mul_a]], %[[b]] +# CHECK-DAG: %[[mul_p:.+]] = torch.aten.mul.Tensor %[[mul_b]], %[[p]] +# CHECK: return %[[mul_p]] +# +# Validate dialect resources exist. +# CHECK: dialect_resources: +# CHECK-DAG: torch_tensor_1_4_torch.float32 +# CHECK-DAG: torch_tensor_3_1_torch.float32 +def test_import_frozen_exported_program(): + # Tests the basic structural premises of import_frozen_exported_program, + # namely that free tensors (buffers) and parameters are treated as + # literals and frozen. + @torch._dynamo.assume_constant_result + def get_a(): + return torch.randn(1, 4) + + class Basic(nn.Module): + def __init__(self): + super().__init__() + self.b = torch.randn(3, 1) + self.p = nn.Parameter(torch.randn(1, 1)) + + def forward(self, x): + return torch.tanh(x) * get_a() * self.b * self.p + + m = fx.export_and_import(Basic(), torch.randn(3, 4)) + print(m) + + +@run +# CHECK-LABEL: test_import_frozen_exported_program_with_func_name +# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> +def test_import_frozen_exported_program_with_func_name(): + @torch._dynamo.assume_constant_result + def get_a(): + return torch.randn(1, 4) + + class Basic(nn.Module): + def __init__(self): + super().__init__() + self.b = torch.randn(3, 1) + self.p = nn.Parameter(torch.randn(1, 1)) + + def forward(self, x): + return torch.tanh(x) * get_a() * self.b * self.p + + m = fx.export_and_import(Basic(), torch.randn(3, 4), func_name="test_net") + print(m) diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py new file mode 100644 index 000000000000..40c633cfc778 --- /dev/null +++ b/test/python/fx_importer/sparse_test.py @@ -0,0 +1,370 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# Requires torch>=2.3.0.dev20240307 +# UNSUPPORTED: true +# RUN: %PYTHON %s | FileCheck %s + +from typing import Any, Callable, Optional + +import torch +import torch.export +import torch.nn as nn + +from torch_mlir.extras.fx_importer import FxImporter +from torch_mlir.extras.fx_importer import SparsityMeta +from torch_mlir import ir +from torch_mlir.dialects import torch as torch_d +from torch_mlir.compiler_utils import run_pipeline_with_repro_report +from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import ( + RefBackendLinalgOnTensorsBackend, +) + + +# All sparse layouts currently supported in torch.sparse. +SPARSE_LAYOUTS = [ + torch.sparse_coo, + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, +] + + +def sparse_metadata(a: torch.Tensor) -> SparsityMeta: + """ + Returns a meta data tuple for the given sparse tensor. + + NOTE: this will be fully replaced by fx graph SparseTensorMetadata + """ + sparse_dim = a.sparse_dim() + dense_dim = a.dense_dim() + batch_dim = a.ndim - dense_dim - sparse_dim + blocksize = None + if a.layout is torch.sparse_coo: + return SparsityMeta( + a.layout, + batch_dim, + sparse_dim, + dense_dim, + blocksize, + a.indices().dtype, + a.indices().dtype, + ) + elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr: + if a.layout is torch.sparse_bsr: + blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3] + return SparsityMeta( + a.layout, + batch_dim, + sparse_dim, + dense_dim, + blocksize, + a.crow_indices().dtype, + a.col_indices().dtype, + ) + elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc: + if a.layout is torch.sparse_bsc: + blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3] + return SparsityMeta( + a.layout, + batch_dim, + sparse_dim, + dense_dim, + blocksize, + a.ccol_indices().dtype, + a.row_indices().dtype, + ) + else: + raise RuntimeError(f"Unsupported sparse layout for {a}") + + +def sparse_export( + f: Callable, args: tuple[Any, ...], kwargs: Optional[dict[str, Any]] = None +) -> torch.export.ExportedProgram: + """ + This is a ***temporary*** wrapper around `torch.export.export` + that eventually should be removed and simply replaced by the + standard API for exporting traced graphs. + + But until issue + + https://github.com/pytorch/pytorch/pull/117907 + + is addressed, this wrapper provides support for the sparse + tensor types by first converting all operands to dense tensors, + building the traced graph as for the dense case, and then + annotation sparse parameters with their actual sparse layout + attributes. This temporary solution accelerates testing + torch-mlir with PyTorch sparse tensors until the issue is + resolved. + """ + # Convert all arguments to dense. + dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args) + mask = [a.layout in SPARSE_LAYOUTS for a in args] + # Build the regular FX traced graph with only dense arguments + # (the current version would crash otherwise, see issue above). + prog = torch.export.export(f, dargs, kwargs) + # Annotate sparse arguments in the graph. Note that we currently + # only account for sparsity defined by the user inputs to the model. + # TODO: support sparsity in model parameters (weights, biases) + # TODO: propagate sparsity into the layers + specs = prog.graph_signature.input_specs + alen = len(specs) + k = 0 + for i, node in enumerate(prog.graph.nodes): + if i >= alen: + break + spec = specs[i] + if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: + if mask[k]: + node.meta["sparsity"] = sparse_metadata(args[k]) + k = k + 1 + return prog + + +def export_and_import(f, *args, **kwargs): + """This method implements Stella's importer, stripped down to essentials.""" + context = ir.Context() + torch_d.register_dialect(context) + fx_importer = FxImporter(context=context) + prog = sparse_export(f, args, kwargs) + fx_importer.import_frozen_program(prog) + return fx_importer.module + + +def sparse_jit(f, *args, **kwargs): + """This method compiles and runs the given callable using linalg backend.""" + # Import module and lower into Linalg IR. + module = export_and_import(f, *args, *kwargs) + run_pipeline_with_repro_report( + module, + ( + "builtin.module(" + "func.func(torch-decompose-complex-ops)," + "torch-backend-to-linalg-on-tensors-backend-pipeline)" + ), + "Lowering TorchFX IR -> Linalg IR", + enable_ir_printing=False, + ) + # Compile with reference Linalg backend. + backend = RefBackendLinalgOnTensorsBackend() + compiled = backend.compile(module) + invoker = backend.load(compiled) + # Prepare input parameters. Sparse input tensors are split into + # their composite tensors. All PyTorch tensors are converted + # to their backing numpy arrays. + # + # TODO: sparse output tensors + # + xargs = [] + for a in args: + if a.layout is torch.sparse_coo: + xargs.append(a.values().numpy()) + # Construct the additional position array required by MLIR with data + # array([0, nnz]). + xargs.append(torch.tensor([0, a._nnz()], dtype=a.indices().dtype).numpy()) + # Transform a tensor into [tensor x ndim] to conform + # MLIR SoA COO representation. + for idx in a.indices(): + xargs.append(idx.numpy()) + elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr: + xargs.append(a.values().numpy()) + xargs.append(a.crow_indices().numpy()) + xargs.append(a.col_indices().numpy()) + elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc: + xargs.append(a.values().numpy()) + xargs.append(a.ccol_indices().numpy()) + xargs.append(a.row_indices().numpy()) + else: + xargs.append(a.numpy()) + # Invoke. + return invoker.main(*xargs) + + +def run(f): + print(f"{f.__name__}") + print("-" * len(f.__name__)) + f() + print() + + +@run +# CHECK-LABEL: test_sparse_sum +# CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64 }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[64,64],f32,#[[$CSR]]>) -> !torch.vtensor<[],f32> { +# CHECK: %[[N:.*]] = torch.constant.none +# CHECK: %[[R:.*]] = torch.aten.sum %[[A]], %[[N]] : !torch.vtensor<[64,64],f32,#[[$CSR]]>, !torch.none -> !torch.vtensor<[],f32> +# CHECK: return %[[R]] : !torch.vtensor<[],f32> +# CHECK: } +# +# CHECK: torch.sparse = tensor(4096.) +# CHECK: torch.mlir = 4096.0 +# +def test_sparse_sum(): + class SumNet(torch.nn.Module): + def __init__(self): + super(SumNet, self).__init__() + + def forward(self, x): + return x.sum() + + net = SumNet() + dense_input = torch.ones(64, 64) + sparse_input = dense_input.to_sparse_csr() + m = export_and_import(net, sparse_input) + print(m) + + # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. + res1 = net(sparse_input) + res2 = sparse_jit(net, sparse_input) + print("torch.sparse =", res1) + print("torch.mlir =", res2) + + +@run +# CHECK-LABEL: test_sparse_SpMV +# CHECK: #[[$BSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 floordiv 2 : dense, d1 floordiv 2 : compressed, d0 mod 2 : dense, d1 mod 2 : dense), posWidth = 64, crdWidth = 64 }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[10,10],f32,#[[$BSR]]>, +# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32> { +# CHECK: %[[R:.*]] = torch.aten.mv %[[A]], %[[B]] : !torch.vtensor<[10,10],f32,#[[$BSR]]>, !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32> +# CHECK: return %[[R]] : !torch.vtensor<[10],f32> +# CHECK: } +# +# CHECK: torch.sparse = tensor([55., 55., 55., 55., 55., 55., 55., 55., 55., 55.]) +# CHECK: torch.mlir = [55. 55. 55. 55. 55. 55. 55. 55. 55. 55.] +# +def test_sparse_SpMV(): + class SpMVNet(torch.nn.Module): + def __init__(self): + super(SpMVNet, self).__init__() + + def forward(self, x, v): + return torch.mv(x, v) + + net = SpMVNet() + dense_vector = torch.arange(1, 11, dtype=torch.float32) + dense_input = torch.ones(10, 10) + sparse_input = dense_input.to_sparse_bsr(blocksize=(2, 2)) + m = export_and_import(net, sparse_input, dense_vector) + print(m) + + # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. + res1 = net(sparse_input, dense_vector) + res2 = sparse_jit(net, sparse_input, dense_vector) + print("torch.sparse =", res1) + print("torch.mlir =", res2) + + +@run +# CHECK-LABEL: test_sparse_SpMM +# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>, +# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[8,8],f32>) -> !torch.vtensor<[8,8],f32> { +# CHECK: %[[R:.*]] = torch.aten.mm %[[A]], %[[B]] : !torch.vtensor<[8,8],f32,#[[$COO]]>, !torch.vtensor<[8,8],f32> -> !torch.vtensor<[8,8],f32> +# CHECK: return %[[R]] : !torch.vtensor<[8,8],f32> +# CHECK: } +# +# CHECK: torch.sparse +# CHECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.], +# CHECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.], +# CHECK: [8., 8., 8., 8., 8., 8., 8., 8.]{{\]}}) +# CHECK: torch.mlir +# CHECK: {{\[}}[8. 8. 8. 8. 8. 8. 8. 8.] +# CHECK-COUNT-6: [8. 8. 8. 8. 8. 8. 8. 8.] +# CHECK: [8. 8. 8. 8. 8. 8. 8. 8.]{{\]}} +# +def test_sparse_SpMM(): + class MatMulNet(torch.nn.Module): + def __init__(self): + super(MatMulNet, self).__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + net = MatMulNet() + dense_input = torch.ones(8, 8) + sparse_input = dense_input.to_sparse_coo() + m = export_and_import(net, sparse_input, dense_input) + print(m) + + # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. + res1 = net(sparse_input, dense_input) + res2 = sparse_jit(net, sparse_input, dense_input) + print("torch.sparse") + print(res1) + print("torch.mlir") + print(res2) + + +@run +# CHECK-LABEL: test_sparse_eltwise +# CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[8,4,2],f32> { +# CHECK: %[[R:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> -> !torch.vtensor<[8,4,2],f32> +# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32> +# CHECK: } +# CHECK: #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[8,4,2],f32> { +# CHECK: %[[R:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> -> !torch.vtensor<[8,4,2],f32> +# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32> +# CHECK: } +# +# CHECK: torch.sparse +# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]), +# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, +# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]), +# CHECK: values=tensor({{\[}}[ -1., -2.], +# CHECK: [ -3., -4.], +# ... +# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32, +# CHECK: layout=torch.sparse_csr) +# CHECK: torch.mlir +# CHECK: {{\[\[}}[ -1. -2.] +# CHECK: [ -3. -4.] +# ... +# CHECK: [-61. -62.] +# CHECK: [-63. -64.]{{\]\]}} +# +def test_sparse_eltwise(): + class EltNet(torch.nn.Module): + def __init__(self): + super(EltNet, self).__init__() + + def forward(self, x): + return -x + + net = EltNet() + dense_input = torch.reshape( + torch.arange(1, 65, dtype=torch.float32), shape=(8, 4, 2) + ) + + # This yields a **batched** CSR. + sparse_input = dense_input.to_sparse_csr(dense_dim=0) + m = export_and_import(net, sparse_input) + print(m) + + # This yields a plain CSR with dense **sub**tensor + sparse_input = dense_input.to_sparse_csr(dense_dim=1) + m = export_and_import(net, sparse_input) + print(m) + + # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. + # + # TODO: note several issues that need to be fixed + # (1) since we do not propagate sparsity into elt-wise, MLIR returns dense result + # (2) for dense_dim=0, this will need a dense(batched) property + sparse_input = dense_input.to_sparse_csr(dense_dim=1) + res1 = net(sparse_input) + res2 = sparse_jit(net, sparse_input) + print("torch.sparse") + print(res1) + print("torch.mlir") + print(res2) diff --git a/test/python/fx_importer/v2.3/lit.local.cfg b/test/python/fx_importer/v2.3/lit.local.cfg new file mode 100644 index 000000000000..b10b239f8b3a --- /dev/null +++ b/test/python/fx_importer/v2.3/lit.local.cfg @@ -0,0 +1,9 @@ +config.unsupported = True + +try: + import torch + if torch.__version__ >= "2.3.0.dev20240207": + print("Enabling Torch v2.3+ tests") + config.unsupported = False +except ModuleNotFoundError: + ... diff --git a/test/python/fx_importer/v2.3/mutation_import.py b/test/python/fx_importer/v2.3/mutation_import.py new file mode 100644 index 000000000000..ef293b8cb134 --- /dev/null +++ b/test/python/fx_importer/v2.3/mutation_import.py @@ -0,0 +1,163 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s | FileCheck %s + +from typing import Optional + +import torch +import torch.export +import torch.nn as nn + +from torch_mlir import fx + +from torch_mlir.ir import ( + Operation, +) + + +def run(f): + print(f"{f.__name__}") + print("-" * len(f.__name__)) + f() + print() + + +@run +# Tests that constants and parameters work generally with the mutation path. +# This doesn't do mutation but ensures that the basics remain functional. +# CHECK-LABEL: test_import_frozen_exported_program +# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> +# CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32> +# CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32> +# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> +# CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]] +# CHECK-DAG: %[[mul_a:.+]] = torch.aten.mul.Tensor %[[tanh]], %[[a]] +# CHECK-DAG: %[[mul_b:.+]] = torch.aten.mul.Tensor %[[mul_a]], %[[b]] +# CHECK-DAG: %[[mul_p:.+]] = torch.aten.mul.Tensor %[[mul_b]], %[[p]] +# CHECK: return %[[mul_p]] +def test_import_frozen_exported_program(): + @torch._dynamo.assume_constant_result + def get_a(): + return torch.randn(1, 4) + + class Basic(nn.Module): + def __init__(self): + super().__init__() + self.b = torch.randn(3, 1) + self.p = nn.Parameter(torch.randn(1, 1)) + + def forward(self, x): + return torch.tanh(x) * get_a() * self.b * self.p + + m = fx.export_and_import( + Basic(), torch.randn(3, 4), experimental_support_mutation=True + ) + print(m) + m.operation.verify() + + +@run +# CHECK-LABEL: test_user_input_mutate +# CHECK: func.func @main(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.tensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> +# CHECK-DAG: %[[arg1_copy:.+]] = torch.copy.to_vtensor %arg1 : !torch.vtensor<[3,4],f32> +# CHECK-DAG: %[[arg1_mul:.+]] = torch.aten.mul.Tensor %[[arg1_copy]], %arg0 +# CHECK-DAG: torch.overwrite.tensor.contents %[[arg1_mul]] overwrites %arg1 +# CHECK-DAG: %[[arg0_mul:.+]] = torch.aten.mul.Tensor %arg0, %[[arg1_mul]] +# CHECK: return %[[arg0_mul]] +def test_user_input_mutate(): + class Basic(nn.Module): + def forward(self, x, y): + y.mul_(x) + return x * y + + m = fx.export_and_import( + Basic(), + torch.randn(3, 4), + torch.randn(3, 4), + experimental_support_mutation=True, + ) + print(m) + m.operation.verify() + + +@run +# CHECK-LABEL: test_frozen_buffer +# CHECK: %[[buffer_literal:.+]] = torch.vtensor.literal +# CHECK: %[[mul:.+]] = torch.aten.mul.Tensor %arg0, %0 +# CHECK: return %[[mul]] +def test_frozen_buffer(): + class Basic(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buffer", torch.randn(3, 4)) + + def forward(self, x): + return x * self.buffer + + m = fx.export_and_import( + Basic(), torch.randn(3, 4), experimental_support_mutation=True + ) + print(m) + m.operation.verify() + + +class ExternalBufferHooks(fx.FxImporterHooks): + def prepare_module(self, module_op: Operation): + module_op.context.allow_unregistered_dialects = True + + def resolve_input(self, gni, value, info): + return Operation.create( + "my_dialect.import_buffer", results=[info.ir_type] + ).result + + +@run +# CHECK-LABEL: test_mutable_buffer +# CHECK: %[[buffer:.+]] = "my_dialect.import_buffer"() : () -> !torch.tensor<[3,4],f32> +# CHECK: %[[mul:.+]] = torch.aten.mul.Tensor %[[buffer]], %arg0 +# CHECK: torch.overwrite.tensor.contents %[[mul]] overwrites %[[buffer]] +# CHECK: return %arg0 +def test_mutable_buffer(): + class Basic(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buffer", torch.randn(3, 4)) + + def forward(self, x): + self.buffer.mul_(x) + return x + + m = fx.export_and_import( + Basic(), + torch.randn(3, 4), + experimental_support_mutation=True, + hooks=ExternalBufferHooks(), + ) + print(m) + m.operation.verify() + + +@run +# CHECK-LABEL: test_mutable_buffer_not_supported_from_literal +# CHECK: ERROR: Cannot import {{.*}} as a literal because it is mutable +def test_mutable_buffer_not_supported_from_literal(): + class Basic(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buffer", torch.randn(3, 4)) + + def forward(self, x): + self.buffer.mul_(x) + return x + + try: + m = fx.export_and_import( + Basic(), + torch.randn(3, 4), + experimental_support_mutation=True, + ) + except ValueError as e: + print("ERROR:", e) diff --git a/test/python/lit.local.cfg b/test/python/lit.local.cfg new file mode 100644 index 000000000000..4cfe04325d94 --- /dev/null +++ b/test/python/lit.local.cfg @@ -0,0 +1,2 @@ +if not config.enable_bindings_python: + config.unsupported = True diff --git a/test/python/onnx_importer/.gitignore b/test/python/onnx_importer/.gitignore new file mode 100644 index 000000000000..ea1472ec1f38 --- /dev/null +++ b/test/python/onnx_importer/.gitignore @@ -0,0 +1 @@ +output/ diff --git a/test/python/onnx_importer/LeakyReLU.onnx b/test/python/onnx_importer/LeakyReLU.onnx new file mode 100644 index 000000000000..f76bccbce92a --- /dev/null +++ b/test/python/onnx_importer/LeakyReLU.onnx @@ -0,0 +1,15 @@ +pytorch0.3:h +" +01" LeakyRelu* +alpha +×#< torch-jit-exportZ +0 + + + +b +1 + + + +B \ No newline at end of file diff --git a/test/python/onnx_importer/_torch_mlir_config.py b/test/python/onnx_importer/_torch_mlir_config.py new file mode 100644 index 000000000000..fdcf61cb81d7 --- /dev/null +++ b/test/python/onnx_importer/_torch_mlir_config.py @@ -0,0 +1,21 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s +# Requires python>=3.10 +# UNSUPPORTED: true + +"""This file exists so that the tests can find/configure torch_mlir. + +It allows the test file to be standalone and used verbatim in other +projects (i.e. by just providing this file on the side). +""" + +from torch_mlir import ir +from torch_mlir.extras import onnx_importer + +def configure_context(context): + from torch_mlir.dialects import torch as torch_d + torch_d.register_dialect(context) diff --git a/test/python/onnx_importer/command_line_test.py b/test/python/onnx_importer/command_line_test.py new file mode 100644 index 000000000000..f379376f0a4d --- /dev/null +++ b/test/python/onnx_importer/command_line_test.py @@ -0,0 +1,146 @@ +# Based on code Copyright (c) Advanced Micro Devices, Inc. +# +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# Requires onnx==1.15.0 +# UNSUPPORTED: true +# RUN: %PYTHON %s --output %t + +from pathlib import Path + +import logging +import shutil +import sys +import subprocess +import unittest +import unittest.mock + +import onnx + +from torch_mlir.tools.import_onnx import __main__ + +# For ONNX models + +import numpy +from onnx import numpy_helper, TensorProto +from onnx.helper import ( + make_model, make_node, make_graph, + make_tensor_value_info) +from onnx.external_data_helper import convert_model_to_external_data +from onnx.checker import check_model + +# Accept the output path on the command line or default to a sibling +# to this file. We have to pop this off explicitly or else unittest +# won't understand. +if len(sys.argv) > 1 and sys.argv[1] == "--output": + OUTPUT_PATH = Path(sys.argv[2]) + del sys.argv[1:3] +else: + OUTPUT_PATH = Path(__file__).resolve().parent / "output" + +OUTPUT_PATH.mkdir(parents=True, exist_ok=True) + + +def const_model() -> onnx.ModelProto: + # Note: data_path must be relative to model_file + + const = make_node( + 'Constant', [], ['c_shape'], 'const', + value=numpy_helper.from_array(numpy.array([4], dtype=numpy.int64))) + cofshape = make_node( + 'ConstantOfShape', ['c_shape'], ['c_out'], 'cofshape', + value=numpy_helper.from_array(numpy.array([1], dtype=numpy.int64))) + + outval = make_tensor_value_info('c_out', TensorProto.INT64, [None]) + graph = make_graph([const, cofshape], 'constgraph', [], [outval]) + + onnx_model = make_model(graph) + check_model(onnx_model) + return onnx_model + + +def linear_model() -> onnx.ModelProto: + # initializers + k_dim = 32 + value = numpy.arange(k_dim).reshape([k_dim, 1]) + value = numpy.asarray(value, dtype=numpy.float32) + A = numpy_helper.from_array(value, name='A') + + value = numpy.array([0.4], dtype=numpy.float32).reshape([1, 1]) + C = numpy_helper.from_array(value, name='C') + + # the part which does not change + X = make_tensor_value_info('X', TensorProto.FLOAT, [1, k_dim]) + Y = make_tensor_value_info('Y', TensorProto.FLOAT, [None, None]) + node1 = make_node('MatMul', ['X', 'A'], ['AX']) + node2 = make_node('Add', ['AX', 'C'], ['Y']) + graph = make_graph([node1, node2], 'lr', [X], [Y], [A, C]) + onnx_model = make_model(graph) + check_model(onnx_model) + return onnx_model + + +ALL_MODELS = [ + const_model, + linear_model +] + + +class CommandLineTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.test_dir = OUTPUT_PATH / "command-line" + shutil.rmtree(cls.test_dir, ignore_errors=True) + cls.test_dir.mkdir(parents=True, exist_ok=True) + + def get_run_path(self, model_name: str) -> Path: + run_path = CommandLineTest.test_dir / model_name + run_path.mkdir(exist_ok=True) + return run_path + + def run_model_intern(self, onnx_model: onnx.ModelProto, model_name: str): + run_path = self.get_run_path(model_name) + model_file = run_path / f"{model_name}-i.onnx" + mlir_file = run_path / f"{model_name}-i.torch.mlir" + onnx.save(onnx_model, model_file) + args = __main__.parse_arguments([ + str(model_file), "-o", str(mlir_file)]) + __main__.main(args) + + def run_model_extern(self, onnx_model: onnx.ModelProto, model_name: str): + run_path = self.get_run_path(model_name) + model_file = run_path / f"{model_name}-e.onnx" + mlir_file = run_path / f"{model_name}-e.torch.mlir" + data_dir_name = f"{model_name}-data" + model_data_dir = run_path / data_dir_name + model_data_dir.mkdir(exist_ok=True) + convert_model_to_external_data( + onnx_model, all_tensors_to_one_file=True, + location=data_dir_name + "/data.bin", + size_threshold=48, + convert_attribute=True) + onnx.save(onnx_model, model_file) + temp_dir = run_path / "temp" + temp_dir.mkdir(exist_ok=True) + args = __main__.parse_arguments([ + str(model_file), "-o", str(mlir_file), "--keep-temps", "--temp-dir", + str(temp_dir), "--data-dir", str(run_path)]) + __main__.main(args) + + def test_all(self): + for model_func in ALL_MODELS: + model_name = model_func.__name__ + model = model_func() + with self.subTest(f"model {model_name}", model_name=model_name): + with self.subTest("Internal data"): + self.run_model_intern(model, model_name) + with self.subTest("External data"): + self.run_model_extern(model, model_name) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/test/python/onnx_importer/import_onnx_tool.runlit b/test/python/onnx_importer/import_onnx_tool.runlit new file mode 100644 index 000000000000..2f170c739896 --- /dev/null +++ b/test/python/onnx_importer/import_onnx_tool.runlit @@ -0,0 +1,5 @@ +# RUN: %PYTHON -m torch_mlir.tools.import_onnx %S/LeakyReLU.onnx | FileCheck %s +# Requires python>=3.10 +# UNSUPPORTED: true + +# CHECK: torch.operator "onnx.LeakyRelu" diff --git a/test/python/onnx_importer/import_smoke_test.py b/test/python/onnx_importer/import_smoke_test.py new file mode 100644 index 000000000000..533ffbc45d70 --- /dev/null +++ b/test/python/onnx_importer/import_smoke_test.py @@ -0,0 +1,143 @@ +# Based on code Copyright (c) Advanced Micro Devices, Inc. +# +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s --output %t +# Requires python>=3.10 +# UNSUPPORTED: true + +from glob import glob +from pathlib import Path + +import logging +import sys +import unittest + +import onnx + +from _torch_mlir_config import ( + configure_context, + ir, + onnx_importer, +) + +# Accept the output path on the command line or default to a sibling +# to this file. We have to pop this off explicitly or else unittest +# won't understand. +if len(sys.argv) > 1 and sys.argv[1] == "--output": + OUTPUT_PATH = Path(sys.argv[2]) + del sys.argv[1:3] +else: + OUTPUT_PATH = Path(__file__).resolve().parent / "output" + + +# TODO: Add some verification and overrides. For now, just use the +# onnx package install for onnx test files, since they were nice +# enough to include the test suite in the deployable. +import onnx.backend.test.data + +ONNX_TEST_DATA_DIR = Path(onnx.backend.test.__file__).resolve().parent / "data" +print(f"ONNX Test Data Dir: {ONNX_TEST_DATA_DIR}") +ONNX_REL_PATHS = glob(f"**/*.onnx", root_dir=ONNX_TEST_DATA_DIR, recursive=True) + +OUTPUT_PATH.mkdir(parents=True, exist_ok=True) + +TEST_CAST_XFAILS = [ + "node_test_ai_onnx_ml_label_encoder_tensor_mapping_model", + "node_test_if_opt_model", +] + +class ImportSmokeTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.unexpected_failure_count = 0 + ImportSmokeTest.actual_failures = [] + + @classmethod + def tearDownClass(cls): + if cls.unexpected_failure_count: + # Print a helpful message with copy-paste XFAIL def. + failure_report_path = OUTPUT_PATH / "import_smoke_test_report.txt" + print( + "Unexpected failures. Writing copy/paste report to:", + failure_report_path, + ) + with open(failure_report_path, "wt") as f: + lines = [f' "{s}",' for s in ImportSmokeTest.actual_failures] + print( + f"Unexpected failures in the following. Copy/paste to update `TEST_CAST_XFAILS`:", + file=f, + ) + print(f"TEST_CAST_XFAILS = [", file=f) + [print(l, file=f) for l in lines] + print(f"]", file=f) + + ImportSmokeTest.actual_failures.clear() + + def load_onnx_model(self, file_path: Path) -> onnx.ModelProto: + raw_model = onnx.load(file_path) + try: + inferred_model = onnx.shape_inference.infer_shapes(raw_model) + except onnx.onnx_cpp2py_export.shape_inference.InferenceError as e: + print("WARNING: Shape inference failure (skipping test):", e) + self.skipTest(reason="shape inference failure") + + # inferred_model = raw_model + return inferred_model + + def run_import_test(self, norm_name: str, rel_path: str): + context = ir.Context() + configure_context(context) + + model_info = onnx_importer.ModelInfo( + self.load_onnx_model(ONNX_TEST_DATA_DIR / rel_path), + ) + m = model_info.create_module(context=context).operation + try: + imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m) + imp.import_all() + m.verify() + finally: + # Use a ".txt" extension to avoid lit test discovery. + with open(OUTPUT_PATH / f"{norm_name}.mlir", "wt") as f: + print(m.get_asm(), file=f) + + def testExists(self): + # We expect a lot of test cases. Die if not the case (i.e. if paths change + # or something). + self.assertGreater(len(ONNX_REL_PATHS), 10) + + +# Generate test methods for each onnx file. +for _rel_path in ONNX_REL_PATHS: + + def attach_test(rel_path): + norm_name = rel_path.removesuffix(".onnx").replace("/", "_") + + def test_method(self: ImportSmokeTest): + try: + self.run_import_test(norm_name, rel_path) + except onnx_importer.OnnxImportError as e: + # All legitimate failures should be caught and reported + # as an OnnxImportError. + ImportSmokeTest.actual_failures.append(norm_name) + if norm_name not in TEST_CAST_XFAILS: + ImportSmokeTest.unexpected_failure_count += 1 + raise e + + test_method.__name__ = f"test_{norm_name}" + + if norm_name in TEST_CAST_XFAILS: + test_method = unittest.expectedFailure(test_method) + + setattr(ImportSmokeTest, test_method.__name__, test_method) + + attach_test(_rel_path) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/test/python/onnx_importer/lit.local.cfg b/test/python/onnx_importer/lit.local.cfg new file mode 100644 index 000000000000..8e0adb7c1c49 --- /dev/null +++ b/test/python/onnx_importer/lit.local.cfg @@ -0,0 +1,5 @@ +try: + import onnx +except ModuleNotFoundError: + print("Skipping onnx tests.. no onnx") + config.unsupported = True diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index ff2205b4ef48..2bae4d4fd6b3 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -4,4 +4,4 @@ # release page, and we use this page as an additional source for the wheels. -f https://xilinx.github.io/torch-mlir/package-index/ --pre -torchvision==0.18.0.dev20240108 +torchvision==0.18.0.dev20240307 diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index 2a9edaac503c..9edb488a0939 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -133,6 +133,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:TransformUtils", ], ) @@ -420,6 +421,29 @@ cc_library( ], ) +cc_library( + name = "TorchMLIRTorchToTensor", + srcs = glob([ + "lib/Conversion/*.h", + "lib/Conversion/TorchToTensor/*.cpp", + ]), + hdrs = glob([ + "include/torch-mlir/Conversion/TorchToTensor/*.h", + ]), + strip_include_prefix = "include", + deps = [ + ":TorchMLIRConversionPassesIncGen", + ":TorchMLIRConversionUtils", + ":TorchMLIRTorchBackendTypeConversion", + ":TorchMLIRTorchConversionDialect", + ":TorchMLIRTorchDialect", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TensorDialect", + ], +) + cc_library( name = "TorchMLIRTorchConversionToMLProgram", srcs = glob([ @@ -515,6 +539,7 @@ cc_library( ":TorchMLIRTorchToSCF", ":TorchMLIRTorchToStablehlo", ":TorchMLIRTorchToTMTensor", + ":TorchMLIRTorchToTensor", ":TorchMLIRTorchToTosa", ], ) @@ -539,6 +564,7 @@ cc_library( ":TorchMLIRTorchToSCF", ":TorchMLIRTorchToStablehlo", ":TorchMLIRTorchToTMTensor", + ":TorchMLIRTorchToTensor", ":TorchMLIRTorchToTosa", "@llvm-project//mlir:ConversionPasses", "@llvm-project//mlir:FuncDialect", @@ -841,7 +867,10 @@ cc_library( hdrs = [ "include/torch-mlir/InitAll.h", ], - copts = ["-DTORCH_MLIR_ENABLE_REFBACKEND"], + copts = [ + "-DTORCH_MLIR_ENABLE_REFBACKEND", + "-DTORCH_MLIR_ENABLE_STABLEHLO", + ], strip_include_prefix = "include", deps = [ ":TorchMLIRConversionPasses", @@ -856,6 +885,8 @@ cc_library( "@llvm-project//mlir:Dialect", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:IR", + "@stablehlo//:linalg_passes", + "@stablehlo//:stablehlo_passes", ], )