diff --git a/.github/workflows/bazelBuildAndTest.yml b/.github/workflows/bazelBuildAndTest.yml index d0d11ad5a6eb..b3cb3b8fb165 100644 --- a/.github/workflows/bazelBuildAndTest.yml +++ b/.github/workflows/bazelBuildAndTest.yml @@ -10,7 +10,11 @@ on: # any in-progress jobs in the same github workflow and github # ref (e.g. refs/heads/main or refs/pull//merge). concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} cancel-in-progress: true diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml index ac92999f1490..b6726cf90b52 100644 --- a/.github/workflows/buildAndTest.yml +++ b/.github/workflows/buildAndTest.yml @@ -11,7 +11,11 @@ on: # any in-progress jobs in the same github workflow and github # ref (e.g. refs/heads/main or refs/pull//merge). concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} cancel-in-progress: true @@ -24,17 +28,14 @@ jobs: strategy: fail-fast: true matrix: - os-arch: [ubuntu-x86_64] # macos-arm64, windows-x86_64 - llvm-build: [in-tree] # out-of-tree - torch-binary: [ON] # OFF + os-arch: [ubuntu-x86_64] #, macos-arm64, windows-x86_64] + llvm-build: [in-tree] #, out-of-tree] + torch-binary: [ON] torch-version: [nightly, stable] exclude: - # Exclude llvm in-tree and pytorch source - - llvm-build: in-tree - torch-binary: OFF - # Exclude llvm out-of-tree and pytorch binary + # Exclude llvm out-of-tree and pytorch stable (to save resources) - llvm-build: out-of-tree - torch-binary: ON + torch-version: stable # Exclude macos-arm64 and llvm out-of-tree altogether - os-arch: macos-arm64 llvm-build: out-of-tree @@ -44,9 +45,6 @@ jobs: llvm-build: out-of-tree - os-arch: windows-x86_64 torch-version: stable - # For PyTorch stable builds, we don't build PyTorch from source - - torch-version: stable - torch-binary: OFF include: # Specify OS versions - os-arch: ubuntu-x86_64 diff --git a/.github/workflows/buildRelease.yml b/.github/workflows/buildRelease.yml index db62cdfc06c8..278590ef3511 100644 --- a/.github/workflows/buildRelease.yml +++ b/.github/workflows/buildRelease.yml @@ -23,11 +23,6 @@ jobs: package: [ torch-mlir ] py_version: [ cp38-cp38, cp310-cp310 ] # cp311-cp311 torch-version: [stable] # nightly - exclude: - - package: torch-mlir-core - py_version: cp38-cp38 - - package: torch-mlir-core - py_version: cp310-cp310 steps: @@ -99,7 +94,7 @@ jobs: runs-on: linux-arm64 strategy: matrix: - package: [ torch-mlir, torch-mlir-core ] + package: [ torch-mlir ] py_version: [ cp311-cp311 ] steps: @@ -169,7 +164,7 @@ jobs: runs-on: macos-latest strategy: matrix: - package: [ torch-mlir, torch-mlir-core ] + package: [ torch-mlir ] steps: - name: Get torch-mlir uses: actions/checkout@v3 @@ -230,7 +225,7 @@ jobs: runs-on: windows-latest strategy: matrix: - package: [ torch-mlir, torch-mlir-core ] + package: [ torch-mlir ] steps: - name: Get torch-mlir uses: actions/checkout@v3 @@ -246,14 +241,8 @@ jobs: - name: Build Python wheels and smoke test. shell: pwsh run: | - if ( "${{ matrix.package }}" -eq "torch-mlir-core" ) - { - $env:TORCH_MLIR_ENABLE_JIT_IR_IMPORTER='0' - $env:TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS='1' - } else { - $env:TORCH_MLIR_ENABLE_JIT_IR_IMPORTER='1' - $env:TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS='0' - } + $env:TORCH_MLIR_ENABLE_JIT_IR_IMPORTER='1' + $env:TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS='0' $env:TORCH_MLIR_PYTHON_PACKAGE_VERSION = '${{ github.event.inputs.python_package_version }}' ./build_tools/python_deploy/build_windows.ps1 diff --git a/.gitignore b/.gitignore index 6b76bc3eae05..5c407428929c 100644 --- a/.gitignore +++ b/.gitignore @@ -26,7 +26,7 @@ __pycache__ bazel-* # Autogenerated files -/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/generated +/projects/ltc/csrc/base_lazy_backend/generated #Docker builds build_oot/ diff --git a/CMakeLists.txt b/CMakeLists.txt index cf33ccac1400..ccbe7ccb3a98 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,6 +29,8 @@ set(CMAKE_CXX_STANDARD 17) # Project options #------------------------------------------------------------------------------- +option(TORCH_MLIR_USE_INSTALLED_PYTORCH "If depending on PyTorch use it as installed in the current Python environment" ON) + option(TORCH_MLIR_ENABLE_REFBACKEND "Enable reference backend" ON) if(TORCH_MLIR_ENABLE_REFBACKEND) add_definitions(-DTORCH_MLIR_ENABLE_REFBACKEND) @@ -149,10 +151,12 @@ endfunction() # Configure CMake. list(APPEND CMAKE_MODULE_PATH ${MLIR_MAIN_SRC_DIR}/cmake/modules) list(APPEND CMAKE_MODULE_PATH ${LLVM_MAIN_SRC_DIR}/cmake) +list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/build_tools/cmake) include(TableGen) include(AddLLVM) include(AddMLIR) +include(AddMLIRPython) ################################################################################ # Setup python. @@ -231,6 +235,4 @@ endif() # Sub-projects #------------------------------------------------------------------------------- -if(TORCH_MLIR_ENABLE_PROJECT_PT1) - add_subdirectory(projects/pt1) -endif() +add_subdirectory(projects) diff --git a/build_tools/autogen_ltc_backend.py b/build_tools/autogen_ltc_backend.py index 02ac0eff09d9..40a64c1c1c2b 100644 --- a/build_tools/autogen_ltc_backend.py +++ b/build_tools/autogen_ltc_backend.py @@ -29,7 +29,6 @@ TORCH_INCLUDE_DIR = TORCH_DIR TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve() TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent -TORCH_MLIR_PT1_DIR = TORCH_MLIR_DIR / "projects" / "pt1" def reindent(text, prefix=""): return indent(dedent(text), prefix) @@ -114,12 +113,12 @@ def __init__(self, binary_dir): self.binary_dir = Path(binary_dir) assert self.binary_dir.is_dir(), f"Binary directory not found: {self.binary_dir}" self.source_yaml = self.binary_dir.joinpath("generated_native_functions.yaml") - self.backend_path = TORCH_MLIR_PT1_DIR.joinpath( - "python", "torch_mlir", "csrc", "base_lazy_backend" + self.backend_path = TORCH_MLIR_DIR.joinpath( + "projects", "ltc", "csrc", "base_lazy_backend" ) assert self.backend_path.is_dir(), f"Backend path not found: {self.backend_path}" self.generated_path = self.binary_dir.joinpath( - "projects", "pt1", "python", "torch_mlir", "csrc", "base_lazy_backend", "generated" + "projects", "ltc", "csrc", "base_lazy_backend", "generated" ) self.generated_path.mkdir(parents=True, exist_ok=True) @@ -415,7 +414,7 @@ def extract_signatures(text): // for ops that dont have a corresponding structured kernel or shape definition #include "shape_inference.h" - #include "torch_mlir/csrc/base_lazy_backend/utils/exception.h" + #include "base_lazy_backend/utils/exception.h" namespace torch {{ namespace lazy {{ {} @@ -467,7 +466,7 @@ def gen_fallback_code(*args, **kwargs): node_base="torch::lazy::TorchMlirNode", node_base_hdr=str(self.backend_path.joinpath("mlir_node.h")), tensor_class=self.tensor_class, - tensor_class_hdr="torch_mlir/csrc/base_lazy_backend/tensor.h", + tensor_class_hdr="base_lazy_backend/tensor.h", create_aten_from_ltc_tensor="CreateFunctionalizedAtenFromLtcTensor", shape_inference_hdr=str(self.generated_path.joinpath("shape_inference.h")), lazy_ir_generator=GenMlirLazyIr, diff --git a/projects/pt1/python/torch_mlir/cmake/modules/TorchMLIRPyTorch.cmake b/build_tools/cmake/TorchMLIRPyTorch.cmake similarity index 100% rename from projects/pt1/python/torch_mlir/cmake/modules/TorchMLIRPyTorch.cmake rename to build_tools/cmake/TorchMLIRPyTorch.cmake diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index b90bfbdc7418..3df3dfb4f453 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -364,9 +364,9 @@ function setup_venv() { function build_out_of_tree() { local torch_from_bin="$1" local python_version="$2" - echo ":::: Build out-of-tree Torch from binary: $torch_from_bin with Python: $python_version" - local torch_version="$3" + echo ":::: Build out-of-tree Torch from binary: $torch_from_bin with Python: $python_version ($torch_version)" + local enable_ltc="ON" if [[ "${torch_version}" == "stable" ]] then diff --git a/build_tools/update_abstract_interp_lib.sh b/build_tools/update_abstract_interp_lib.sh index d33c69536850..cb44a4e8b27c 100755 --- a/build_tools/update_abstract_interp_lib.sh +++ b/build_tools/update_abstract_interp_lib.sh @@ -42,6 +42,6 @@ if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then fi PYTHONPATH="${pypath}" python \ - -m torch_mlir.dialects.torch.importer.jit_ir.build_tools.abstract_interp_lib_gen \ + -m torch_mlir.jit_ir_importer.build_tools.abstract_interp_lib_gen \ --pytorch_op_extensions=${ext_module:-""} \ --torch_transforms_cpp_dir="${torch_transforms_cpp_dir}" diff --git a/build_tools/update_torch_ods.sh b/build_tools/update_torch_ods.sh index e0564a62dff8..cb0599f16f10 100755 --- a/build_tools/update_torch_ods.sh +++ b/build_tools/update_torch_ods.sh @@ -43,7 +43,7 @@ fi set +u PYTHONPATH="${PYTHONPATH}:${pypath}" python \ - -m torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen \ + -m torch_mlir.jit_ir_importer.build_tools.torch_ods_gen \ --torch_ir_include_dir="${torch_ir_include_dir}" \ --pytorch_op_extensions="${ext_module}" \ --debug_registry_dump="${torch_ir_include_dir}/JITOperatorRegistryDump.txt" diff --git a/build_tools/write_env_file.sh b/build_tools/write_env_file.sh index 05179c56a07c..8f3c9a59357f 100755 --- a/build_tools/write_env_file.sh +++ b/build_tools/write_env_file.sh @@ -13,7 +13,7 @@ portable_realpath() { td="$(portable_realpath "$(dirname "$0")"/..)" build_dir="$(portable_realpath "${TORCH_MLIR_BUILD_DIR:-$td/build}")" -python_packages_dir="$build_dir/tools/torch-mlir/python_packages" +python_packages_dir="$build_dir/python_packages" write_env_file() { echo "Updating $build_dir/.env file" diff --git a/docs/Torch-ops-E2E-implementation.md b/docs/Torch-ops-E2E-implementation.md index 153246f375b2..53031c9ce1f4 100644 --- a/docs/Torch-ops-E2E-implementation.md +++ b/docs/Torch-ops-E2E-implementation.md @@ -17,7 +17,7 @@ The end-to-end test is important to check the correctness of the other steps. ### Step 2. Update ods -Update [torch_ods_gen.py](https://github.com/llvm/torch-mlir/blob/main/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py) with the new op and run [update_torch_ods.sh](https://github.com/llvm/torch-mlir/blob/main/build_tools/update_torch_ods.sh) to generate the ods. Running `update_torch_ods.sh` would dump all the operators with schema into `JITOperatorRegistryDump.txt`. It’s convenient to look for ops signatures and operands names in this file. +Update [torch_ods_gen.py](https://github.com/llvm/torch-mlir/blob/main/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py) with the new op and run [update_torch_ods.sh](https://github.com/llvm/torch-mlir/blob/main/build_tools/update_torch_ods.sh) to generate the ods. Running `update_torch_ods.sh` would dump all the operators with schema into `JITOperatorRegistryDump.txt`. It’s convenient to look for ops signatures and operands names in this file. ### Step 3. Propagate types It’s essential to make sure the new op implements shape and dtype inference. See [abstract_interp_lib](https://github.com/llvm/torch-mlir/blob/main/docs/abstract_interp_lib.md) for information on adding shape and dtype inference. diff --git a/docs/abstract_interp_lib.md b/docs/abstract_interp_lib.md index 14ffc2181a65..eb862e6bb40e 100644 --- a/docs/abstract_interp_lib.md +++ b/docs/abstract_interp_lib.md @@ -26,7 +26,7 @@ The two main use cases are: ## Architecture Functions are defined as TorchScript-able Python functions in -`python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py`. +`python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py`. The signatures of the functions are systematically derived from Torch JIT operator registry. Most shape functions are expected to reuse the upstream helper functions diff --git a/docs/adding_an_e2e_test.md b/docs/adding_an_e2e_test.md index 1c961c5c19f8..7b74b904a0f8 100644 --- a/docs/adding_an_e2e_test.md +++ b/docs/adding_an_e2e_test.md @@ -5,7 +5,7 @@ Adding support for a Torch operator in Torch-MLIR should always be accompanied by at least one end-to-end test to make sure the implementation of the op matches the behavior of PyTorch. The tests live in the -`torch-mlir/python/torch_mlir_e2e_test/test_suite/` directory. When adding a new +`torch-mlir/projects/pt1/python/torch_mlir_e2e_test/test_suite` directory. When adding a new test, choose a file that best matches the op you're testing, and if there is no file that best matches add a new file for your op. @@ -87,7 +87,7 @@ following order: 1. Shape of input tensor. Use `-1` for dynamic dimensions 2. Dtype of the input tensor -3. Boolean representing whether the input tensor [has value semantics](https://github.com/llvm/torch-mlir/blob/ba17a4d6c09b4bbb4ef21b1d8d4a93cb056be109/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.h#L54-L67). This +3. Boolean representing whether the input tensor [has value semantics](https://github.com/llvm/torch-mlir/blob/ba17a4d6c09b4bbb4ef21b1d8d4a93cb056be109/python/torch_mlir/jit_ir_importer/csrc/class_annotator.h#L54-L67). This will always be true for E2E tests, since the [Torch-MLIR backend contract](architecture.md#the-backend-contract) requires all tensors in the IR to eventually have value semantics. diff --git a/docs/architecture.md b/docs/architecture.md index e503ba40d93b..8ee6bfda8a0a 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -55,14 +55,14 @@ factored such that we can handle this with one core import path, which is through the PyTorch "[JIT IR](https://github.com/pytorch/pytorch/blob/78c8a0d75220bdd4955415b5f81509e005af4232/torch/csrc/jit/OVERVIEW.md)", and lives in -[torch-mlir/python/torch_mlir/dialects/torch/importer/jit_ir](https://github.com/llvm/torch-mlir/tree/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/dialects/torch/importer/jit_ir). +[torch-mlir/python/torch_mlir/jit_ir_importer](https://github.com/llvm/torch-mlir/tree/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/dialects/torch/importer/jit_ir). The JIT IR is a highly principled IR that faithfully models a Python subset (+ tensors, the PyTorch op registry, and a few other things). All the other PyTorch program representations can eventually bottom-out on the JIT IR via some path provided by PyTorch. The `torch` dialect is almost entirely in 1:1 correspondence with the JIT IR -- this allows the importer to be extremely small (the core is -[under 500 lines of code](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp#L1)). +[under 500 lines of code](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/jit_ir_importer/csrc/node_importer.cpp#L1)). ### Ops @@ -70,7 +70,7 @@ See [TorchOps.td](https://github.com/llvm/torch-mlir/blob/114f48e96c578ee76a6f83 The ops in the `torch` dialect are almost entirely generated based on the PyTorch JIT IR operator registry via the script -[torch_ods_gen.py](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py#L1) (invoked via [update_torch_ods.sh](https://github.com/llvm/torch-mlir/blob/main/build_tools/update_torch_ods.sh)). +[torch_ods_gen.py](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py#L1) (invoked via [update_torch_ods.sh](https://github.com/llvm/torch-mlir/blob/main/build_tools/update_torch_ods.sh)). This script queries the registry and generates MLIR [ODS](https://mlir.llvm.org/docs/OpDefinitions/) in [GeneratedTorchOps.td](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td#L1). We have a guide for [adding a new op end-to-end](https://github.com/llvm/torch-mlir/wiki/Torch-ops-E2E-implementation). @@ -195,7 +195,7 @@ values. When one `torch.jit.script`'s a `torch.nn.Module`, the result is actually an `IValue` that represents the module, with a hierarchy of children `IValue`'s. Strictly speaking, JIT IR `torch::jit::Graph`'s are only used to represent the bodies of methods on the modules. So in addition to importing the -JIT IR, we also need to import the `IValue`'s. This happens inside [ivalue_importer.cpp](https://github.com/llvm/torch-mlir/blob/fde390c7669e29362b18388448ef2b188713383f/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.cpp#L1). +JIT IR, we also need to import the `IValue`'s. This happens inside [ivalue_importer.cpp](https://github.com/llvm/torch-mlir/blob/fde390c7669e29362b18388448ef2b188713383f/python/torch_mlir/jit_ir_importer/csrc/ivalue_importer.cpp#L1). Most of the IValue modeling can reuse `torch` dialect ops that already exist otherwise, such as `torch.constant.int` to represent an int in the object graph. diff --git a/docs/development.md b/docs/development.md index d2b86504bf28..c60312e7ac5e 100644 --- a/docs/development.md +++ b/docs/development.md @@ -20,6 +20,7 @@ source mlir_venv/bin/activate python -m pip install --upgrade pip # Install latest PyTorch nightlies and build requirements. python -m pip install -r requirements.txt +python -m pip install -r torchvision-requirements.txt ``` ## CMake Build @@ -108,25 +109,25 @@ cmake --build build ### Linux and macOS ```shell -export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/examples +export PYTHONPATH=`pwd`/build/python_packages/torch_mlir:`pwd`/projects/pt1/examples ``` ### Windows PowerShell ```shell -$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/examples" +$env:PYTHONPATH = "$PWD/build/python_packages/torch_mlir;$PWD/projects/pt1/examples" ``` ## Testing MLIR output in various dialects -To test the compiler's output to the different MLIR dialects, you can use the example `examples/torchscript_resnet18_all_output_types.py`. +To test the compiler's output to the different MLIR dialects, you can use the example `projects/pt1/examples/torchscript_resnet18_all_output_types.py`. Make sure you have activated the virtualenv and set the `PYTHONPATH` above (if running on Windows, modify the environment variable as shown above): ```shell source mlir_venv/bin/activate -export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/examples -python examples/torchscript_resnet18_all_output_types.py +export PYTHONPATH=`pwd`/build/tpython_packages/torch_mlir:`pwd`/projects/pt1/examples +python projects/pt1/examples/torchscript_resnet18_all_output_types.py ``` This will display the Resnet18 network example in three dialects: TORCH, LINALG on TENSORS and TOSA. @@ -331,8 +332,8 @@ Torch-MLIR has two types of tests: 1. End-to-end execution tests. These compile and run a program and check the result against the expected output from execution on native Torch. These use a homegrown testing framework (see - `python/torch_mlir_e2e_test/torchscript/framework.py`) and the test suite - lives at `python/torch_mlir_e2e_test/test_suite/__init__.py`. + `projects/pt1/python/torch_mlir_e2e_test/framework.py`) and the test suite + lives at `projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py`. 2. Compiler and Python API unit tests. These use LLVM's `lit` testing framework. For example, these might involve using `torch-mlir-opt` to run a pass and diff --git a/docs/ltc_backend.md b/docs/ltc_backend.md index ae3cc887c7dc..b0177542899b 100644 --- a/docs/ltc_backend.md +++ b/docs/ltc_backend.md @@ -12,7 +12,7 @@ [Lazy Tensor Core](https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/tutorial.md) is a tracing system in PyTorch which is supported as an entry point to Torch-MLIR. After registering an LTC backend, all operations performed on lazy tensors are recorded and handed off to the backend implementation. -LTC support is provided through an abstract [`TorchMlirBackendImpl`](../python/torch_mlir/csrc/base_lazy_backend/backend_impl.h) class, which handles the conversion to MLIR. +LTC support is provided through an abstract [`TorchMlirBackendImpl`](../projects/ltc/csrc/base_lazy_backend/backend_impl.h) class, which handles the conversion to MLIR. Implementations based on this abstract class will be able to specify their own compile and execution workflows. Additional details about how to implement a custom backend is available [below](#Implementing-a-custom-backend). @@ -27,7 +27,7 @@ View examples [here](ltc_examples.md). - The [autogen files](#autogen-files) are generated by this script based on the list of supported ops, which includes all ops from [`GeneratedTorchOps.td`](https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td), excluding those explicitly blacklisted in the YAML file -### Autogen Files ([`python/torch_mlir/csrc/base_lazy_backend/generated`](../python/torch_mlir/csrc/base_lazy_backend/generated)) +### Autogen Files ([`projects/ltc/csrc/base_lazy_backend/generated`](../projects/ltc/csrc/base_lazy_backend/generated)) Generated files are created in this directory, which is ignored by version control. - `LazyIr.h` @@ -41,7 +41,7 @@ Generated files are created in this directory, which is ignored by version contr - `shape_inference.{cpp,h}` - Shape inference headers for supported ops and autogen'd placeholders for unimplemented functions -### Base Backend ([`python/torch_mlir/csrc/base_lazy_backend`](../python/torch_mlir/csrc/base_lazy_backend)) +### Base Backend ([`projects/ltc/csrc/base_lazy_backend`](../projects/ltc/csrc/base_lazy_backend)) - `backend_impl.{cpp,h}` - Base LTC backend to setup Torch-MLIR lowering context diff --git a/include/torch-mlir/Conversion/CMakeLists.txt b/include/torch-mlir/Conversion/CMakeLists.txt index d6552314999b..c2e757f7a0ff 100644 --- a/include/torch-mlir/Conversion/CMakeLists.txt +++ b/include/torch-mlir/Conversion/CMakeLists.txt @@ -1,3 +1,5 @@ +add_subdirectory(TorchOnnxToTorch) + set(LLVM_TARGET_DEFINITIONS Passes.td) if(TORCH_MLIR_ENABLE_STABLEHLO) mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/CMakeLists.txt b/include/torch-mlir/Conversion/TorchOnnxToTorch/CMakeLists.txt new file mode 100644 index 000000000000..a58ce5bf9b7d --- /dev/null +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(TorchMLIRConversionTorchOnnxToTorchPassIncGen) +add_mlir_doc(Passes TorchMLIRConversionTorchOnnxToTorchPasses ./ -gen-pass-doc) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.h new file mode 100644 index 000000000000..6eea35c9d255 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.h @@ -0,0 +1,27 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H +#define TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include + +namespace mlir::torch::onnx_c { + +std::unique_ptr> createTorchOnnxToTorchPass(); + +/// Registers all torch-mlir conversion passes. +void registerTorchOnnxToTorchPasses(); + +} // namespace mlir::torch::onnx_c + +#endif // TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.td b/include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.td new file mode 100644 index 000000000000..b92649d025a6 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.td @@ -0,0 +1,26 @@ +//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES +#define TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertTorchOnnxToTorch : Pass<"convert-torch-onnx-to-torch", "func::FuncOp"> { + let summary = "Converts ONNX custom ops in the torch dialect to native torch ops"; + let description = [{ + Converts equivalent ONNX custom ops to built-in equivalents. + + See the README for a detailed description of how this operates. + }]; + + let constructor = "mlir::torch::onnx_c::createTorchOnnxToTorchPass()"; +} + +#endif // TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h new file mode 100644 index 000000000000..5b144503c0ec --- /dev/null +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -0,0 +1,195 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHONNX_CONVERSION_UTILS_H +#define TORCHMLIR_CONVERSION_TORCHONNX_CONVERSION_UTILS_H + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" +#include + +namespace mlir::torch::onnx_c { + +/// Used during ONNX pattern matching to bind common patterns of operands, +/// result types and attributes to local variables in a way that is easy +/// to fail the pattern if constraints are violated. Most methods return +/// a ParseResult, which allows for chaining like: +/// +/// if (binder.tensorOperand(foo) || binder.tensorResultType(t)) +/// return failure(); +struct OpBinder { + OpBinder(Operation *op) : op(op) {} + + Location getLoc() { return op->getLoc(); } + + // Operand matches of different arities. + ParseResult tensorOperand(Value &value0) { + if (op->getNumOperands() != 1) + return failure(); + value0 = op->getOperand(0); + if (!toValidTensorType(value0.getType())) + return failure(); + return success(); + } + + ParseResult tensorOperands(Value &value0, Value &value1) { + if (op->getNumOperands() != 2) + return failure(); + value0 = op->getOperand(0); + value1 = op->getOperand(1); + if (!toValidTensorType(value0.getType()) || + !toValidTensorType(value1.getType())) + return failure(); + return success(); + } + + ParseResult tensorOperandAtIndex(Value &valueIdx, int64_t idx) { + if (idx >= op->getNumOperands()) + return failure(); + valueIdx = op->getOperand(idx); + if (!toValidTensorType(valueIdx.getType())) + return failure(); + return success(); + } + + // Result type matchers of different arities. + ParseResult tensorResultType(Torch::ValueTensorType &type0) { + if (op->getNumResults() != 1) + return failure(); + auto t = toValidTensorType(op->getResult(0).getType()); + if (!t) + return failure(); + type0 = t; + return success(); + } + + // Attribute accessors. + ParseResult s64BoolAttr(bool &value, StringRef nameSuffix, + bool defaultValue = false) { + SmallString<64> name("torch.onnx."); + name.append(nameSuffix); + auto attr = op->getAttr(name); + if (!attr) { + value = defaultValue; + return success(); + } + if (auto integerAttr = dyn_cast(attr)) { + IntegerType t = cast(integerAttr.getType()); + if (!t.isSigned() || t.getWidth() != 64) + return failure(); + value = static_cast(integerAttr.getSInt()); + return success(); + } + return failure(); + } + + ParseResult s64IntegerAttr(int64_t &value, StringRef nameSuffix, + int64_t defaultValue = 0) { + SmallString<64> name("torch.onnx."); + name.append(nameSuffix); + auto attr = op->getAttr(name); + if (!attr) { + value = defaultValue; + return success(); + } + if (auto integerAttr = dyn_cast(attr)) { + IntegerType t = cast(integerAttr.getType()); + if (!t.isSigned() || t.getWidth() != 64) + return failure(); + value = integerAttr.getSInt(); + return success(); + } + return failure(); + } + + ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix, + std::string defaultValue = "") { + SmallString<64> name("torch.onnx."); + name.append(nameSuffix); + auto attr = op->getAttr(name); + if (!attr) { + value = defaultValue; + return success(); + } + if (auto stringAttr = dyn_cast(attr)) { + value = stringAttr.str(); + return success(); + } + return failure(); + } + + Torch::ValueTensorType toValidTensorType(Type t) { + auto tt = dyn_cast(t); + if (tt && tt.hasSizes()) + return tt; + return {}; + } + + Operation *op; +}; + +/// We use a single pattern per ONNX domain to handle all named custom +/// ops. +/// This allows us to avoid the n^2 problem on pattern application by +/// implementing a secondary index based on the name and sinceVersion +/// attributes. +/// It also lets us add some ergonomics for trivial cases. +class OnnxCustomOpConversionPattern + : public OpConversionPattern { +public: + using HandlerFn = LogicalResult (*)(OpBinder binder, + ConversionPatternRewriter &rewriter); + struct HandlerReg { + HandlerReg(HandlerFn callback, int64_t sinceVersion) + : callback(callback), sinceVersion(sinceVersion) {} + HandlerFn callback; + int64_t sinceVersion; + }; + + OnnxCustomOpConversionPattern(MLIRContext *context, std::string domainPrefix, + int64_t domainVersion) + : OpConversionPattern(context), domainPrefix(std::move(domainPrefix)), + domainVersion(domainVersion) {} + + LogicalResult + matchAndRewrite(Torch::OperatorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; + + /// Adds all fully qualified operator names to the given set. + /// This is typically used for implementing a dynamic legality + /// check for torch.operator names. + void populateLegalizedNames(DenseSet &legalizedNames); + + /// Register a conversion for a specific ONNX operator. For the + /// default domain, this is the canonical ONNX operator name (i.e. + /// "Acos"). + /// Multiple conversions can be registered for the same op, most + /// commonly differing by their `sinceVersion`. + void onOp(StringRef name, int64_t sinceVersion, HandlerFn callback); + +private: + std::string domainPrefix; + int64_t domainVersion; + DenseMap> namedHandlers; +}; + +// Patterns are split into chunks to speed compile time and reduce some +// contention on the same source files. +void populateDefaultDomainAtoF(OnnxCustomOpConversionPattern &patterns); +void populateDefaultDomainGtoP(OnnxCustomOpConversionPattern &patterns); +void populateDefaultDomainQtoZ(OnnxCustomOpConversionPattern &patterns); + +} // namespace mlir::torch::onnx_c + +#endif // TORCHMLIR_CONVERSION_TORCHONNX_CONVERSION_UTILS_H diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/README.md b/include/torch-mlir/Conversion/TorchOnnxToTorch/README.md new file mode 100644 index 000000000000..6de1cc923411 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/README.md @@ -0,0 +1,133 @@ +# TorchOnnx To Torch Conversions + +We enable the direct representation of many ONNX features directly in +the `torch` dialect as `torch.operator` custom ops with names like +`onnx.{OperatorName}`. The majority of ONNX operators are represented +with a systematic transformation. See +[onnx_importer.py](https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/importers/onnx_importer.py) +for the reference importer which complies with the rules below +(this is planned to be upstreamed to torch-mlir proper in the near +future). + +## Adding new ONNX operators + +With the exception of certain special or complicated ONNX operators, most +are relatively straight-forward to map, following this general procedure: + +* Plan the ops you wish to support by consulting the + [ONNX operator database](https://onnx.ai/onnx/operators/). + * This database has detailed diffs wrt different support versions but + at the level of detail we operate, most version diffs are inconsequential + and just require a bit more pattern support. + * This typically applies to generalization of broadcasting semantics, + expanded type support, and other things of the like. +* *Prerequisite*: Add support for the op to torch-mlir if it does not + already exist. +* Open the corresponding implementation file `DefaultDomainXtoY.cpp` + corresponding with the alphabetic sort of the op and add a conversion. +* Generate successful test cases: + * Either run the Turbine importer to produce MLIR output for all + ops/models in the ONNX test suite or use a dump that someone has + generated: + * [2023-Nov-21](https://drive.google.com/file/d/1P6QaRXGnCeApjdjNmykLxWa-yqMmIO-d/view?usp=sharing) + * There are often many variants of tests for checking conformance of + different historic ONNX encodings, but these are often not load bearing + at the MLIR level. + * Pick a handful of test cases and add them to + `test/Conversion/TorchOnnxToTorch/simple_ops_x_to_y.mlir` corresponding to an + alphabetic breakdown. At this time, ignore tests that are not exercising + useful differences in the pattern implementations. +* Generate failure test cases: + * Some ops have forms that do not (easily) map to torch-mlir. If you leave + an op under-implemented, add a failing test case to + `test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir`. +* Optional but recommended: Use your test case files to fuzz against the + torch-mlir backend of your choice by running a backend conversion pipeline + and fixing any crashes/issues. +* Send a patch with your changes. + +## ONNX proto to `torch` dialect mapping + +### Type Conversion + +* Tensors: ONNX tensor types are converted to `torch.vtensor` + with static and dynamic dimensions. We require that shape + inference has run to produce ranked tensors. +* Tensor element types are directly converted to corresponding + MLIR types as used by the rest of torch-mlir. +* String, sequence and sparse tensor types are presently not mapped. + +### Attributes + +A subset of attributes types are converted directly to an attribute +dict on the op with a name like `torch.onnx.{AttributeName}`. The +following attribute type mappings are made: + +* `FLOAT`: `FloatAttr` +* `INT`: Signed `IntegerAttr` of width 64 +* `STRING`: `StringAttr` +* `TENSOR`: Converted to one of: + * `DenseResourceElementsAttr` for inlined `raw_data` + * `DenseElementsAttr` for splats + * `DenseElementsAttr` for inlined typed proto initialization +* `FLOATS`: `ArrayAttr` of `FloatAttr` +* `INTS`: `ArrayAttr` of signed `IntegerAttr` of width 64 +* `STRINGS`: `ArrayAttr` of `StringAttr` +* `TENSORS`: `ArrayAttr` of corresponding `TENSOR` conversion + +The following attribute types have no present, systematic conversion. +Their presence on an op indicates that the op is a special form, which +must be handled specially: + +* `GRAPH` +* `SPARSE_TENSOR` (TBD: it is possible to handle this systematically if + useful). +* `TYPE_PROTO` (TBD: it may be possible to handle this systematically if + useful). +* Plural equivalents of the above. + +### Default operation conversion + +Operations are converted to a `torch.operator` with name `onnx.{OperatorName}`. +The constraint that the ONNX graph is topologically sorted and free of +cycles matches the SSA form. Operands and results are mapped directly. + +This conversion only applies to the default (empty) domain. + +### Quantization information + +Quantization parameters are carried out of line in the ONNX protobuf +and will be repatriated upon import to torch. The exact mechanism is +not yet implemented. + +### Version and metadata + +The `IsolatedFromAbove` parent of the ops can contain the following +metadata: + +* `torch.onnx_meta.ir_version`: 64bit `IntegerAttr` corresponding to + `ModelProto.ir_version`. +* `torch.onnx_meta.producer_name`: `StringAttr` corresponding to + `ModelProto.producer_name`. +* `torch.onnx_meta.producer_version`: `StringAttr` corresponding to + `ModelProto.producer_version`. +* `torch.onnx_meta.opset_version`: 64bit `IntegerAttr` corresponding + to `ModelProto.opset_import.version` for the domain "" (empty). + Will be ommitted if the default opset is not included. +* `torch.onnx_meta.opset_versions`: DictAttr of 64bit `IntegerAttr` + for each non default domain. + +Generally, the importer handles variations in `ir_version` whereas +the transformations here handle opset version differences. Version +independent transformations are encouraged where possible if there +are only minor variations of an op. Major variations should use +`since_version` sensitive patterns. + +### Special op forms + +Certain ONNX operators map to different structural components of +torch-mlir's representation: + +* `ConstantOfShape`: Mapped to `torch.vtensor.literal` with + a corresponding `value` attribute. + diff --git a/lib/Conversion/TorchToLinalg/Utils.h b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h similarity index 93% rename from lib/Conversion/TorchToLinalg/Utils.h rename to include/torch-mlir/Conversion/TorchToLinalg/Utils.h index 3bee8d642533..134fbeca46dc 100644 --- a/lib/Conversion/TorchToLinalg/Utils.h +++ b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h @@ -8,6 +8,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" namespace mlir { namespace torch { @@ -88,6 +89,12 @@ Value removeSizeInformation(OpBuilder &b, Location loc, Value tensor); Value convertTensorToElementType(OpBuilder &b, Location loc, Value tensor, Type elementType); +// Convert a scalar type to the corresponding builtin type in the +// linalg-on-tensors backend. +FailureOr +getBackendTypeForScalarType(MLIRContext *context, + torch_upstream::ScalarType dtypeInt); + } // namespace torch_to_linalg } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 0ff3f3045faa..c000411e8e44 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -13,7 +13,7 @@ // This file is automatically generated. Please do not edit. // Generated via: // ``` -// python -m torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen +// build_tools/update_torch_ods.sh // ``` // //===----------------------------------------------------------------------===// @@ -886,6 +886,51 @@ def Torch_AtenCos_Op : Torch_Op<"aten.cos_", [ }]; } +def Torch_AtenAcosOp : Torch_Op<"aten.acos", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::acos : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAcosOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAcosOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenAcos_Op : Torch_Op<"aten.acos_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::acos_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAcos_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAcos_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenAtanOp : Torch_Op<"aten.atan", [ AllowsTypeRefinement, HasValueSemantics, @@ -1023,51 +1068,6 @@ def Torch_AtenAsin_Op : Torch_Op<"aten.asin_", [ }]; } -def Torch_AtenAcosOp : Torch_Op<"aten.acos", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::acos : (Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenAcosOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenAcosOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - -def Torch_AtenAcos_Op : Torch_Op<"aten.acos_", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::acos_ : (Tensor) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self - ); - let results = (outs - Torch_NonValueTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenAcos_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenAcos_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - def Torch_AtenNegOp : Torch_Op<"aten.neg", [ AllowsTypeRefinement, HasValueSemantics, @@ -2192,55 +2192,6 @@ def Torch_AtenMaskedFill_ScalarOp : Torch_Op<"aten.masked_fill_.Scalar", [ }]; } -def Torch_AtenMaskedFillTensorOp : Torch_Op<"aten.masked_fill.Tensor", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$mask, - AnyTorchTensorType:$value - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenMaskedFillTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); - } - void AtenMaskedFillTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); - } - }]; -} - -def Torch_AtenMaskedFill_TensorOp : Torch_Op<"aten.masked_fill_.Tensor", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::masked_fill_.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$mask, - Torch_NonValueTensorType:$value - ); - let results = (outs - Torch_NonValueTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenMaskedFill_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); - } - void AtenMaskedFill_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); - } - }]; -} - def Torch_AtenClampOp : Torch_Op<"aten.clamp", [ AllowsTypeRefinement, HasValueSemantics, @@ -3030,6 +2981,53 @@ def Torch_AtenBitwiseXor_TensorOp : Torch_Op<"aten.bitwise_xor_.Tensor", [ }]; } +def Torch_AtenBitwiseLeftShiftTensorOp : Torch_Op<"aten.bitwise_left_shift.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::bitwise_left_shift.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBitwiseLeftShiftTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenBitwiseLeftShiftTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenBitwiseLeftShift_TensorOp : Torch_Op<"aten.bitwise_left_shift_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::bitwise_left_shift_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$other + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBitwiseLeftShift_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenBitwiseLeftShift_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenBitwiseRightShiftTensorOp : Torch_Op<"aten.bitwise_right_shift.Tensor", [ AllowsTypeRefinement, HasValueSemantics, @@ -3748,6 +3746,56 @@ def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [ }]; } +def Torch_AtenMaskedFillTensorOp : Torch_Op<"aten.masked_fill.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$mask, + AnyTorchTensorType:$value + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaskedFillTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenMaskedFillTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; + let hasCanonicalizer = 1; +} + +def Torch_AtenMaskedFill_TensorOp : Torch_Op<"aten.masked_fill_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::masked_fill_.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$mask, + Torch_NonValueTensorType:$value + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaskedFill_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenMaskedFill_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [ AllowsTypeRefinement, HasValueSemantics, @@ -8489,6 +8537,31 @@ def Torch_AtenOneHotOp : Torch_Op<"aten.one_hot", [ }]; } +def Torch_AtenEinsumOp : Torch_Op<"aten.einsum", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::einsum : (str, Tensor[], int[]?) -> (Tensor)`"; + let arguments = (ins + Torch_StringType:$equation, + AnyTorchListOfTensorType:$tensors, + AnyTorchOptionalListOfTorchIntType:$path + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEinsumOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenEinsumOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [ AllowsTypeRefinement, HasValueSemantics, @@ -11700,7 +11773,6 @@ def Torch_AtenStackOp : Torch_Op<"aten.stack", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; - let hasFolder = 1; } def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [ @@ -14407,6 +14479,30 @@ def Torch_PrimsCollapseOp : Torch_Op<"prims.collapse", [ }]; } +def Torch_PrimsSplitDimOp : Torch_Op<"prims.split_dim", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `prims::split_dim : (Tensor, int, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$a, + Torch_IntType:$dim, + Torch_IntType:$outer_length + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult PrimsSplitDimOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void PrimsSplitDimOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_PrimsSqueezeOp : Torch_Op<"prims.squeeze", [ AllowsTypeRefinement, ReadOnly diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index 313050d3b69c..842c86defb74 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -40,9 +40,8 @@ TypedValue reshapeTo(Location loc, PatternRewriter &rewriter, std::optional matchLegalConstantIndexIntoListOfSize(Value v, int64_t length); torch_upstream::ScalarType getScalarTypeForType(Type type); -FailureOr getTypeForScalarType( - MLIRContext *context, torch_upstream::ScalarType dtypeInt, - mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed); +FailureOr getTypeForScalarType(MLIRContext *context, + torch_upstream::ScalarType dtypeInt); Type getTypeForTorchType( MLIRContext *context, Type type, diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 8956066b8769..d9030c23a66f 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -14,16 +14,19 @@ set(LinkedLibs MLIRTosaDialect MLIRSupport - TorchMLIRTorchPasses - TorchMLIRTorchConversionDialect - + # Dialects. + TorchMLIRTMTensorDialect TorchMLIRTorchDialect - TorchMLIRTorchConversionPasses + TorchMLIRTorchConversionDialect + # Dialect passes. TorchMLIRTMTensorPasses - TorchMLIRTMTensorDialect + TorchMLIRTorchConversionPasses + TorchMLIRTorchPasses + # Conversion passes. TorchMLIRConversionPasses + TorchMLIRTorchOnnxToTorch ) if(TORCH_MLIR_ENABLE_REFBACKEND) diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index f26b4d6e895e..afbe775d3a20 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(TorchOnnxToTorch) add_subdirectory(TorchToLinalg) add_subdirectory(TorchToSCF) add_subdirectory(TorchToArith) diff --git a/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt b/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt new file mode 100644 index 000000000000..807db64eac64 --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_conversion_library(TorchMLIRTorchOnnxToTorch + DefaultDomainAtoF.cpp + DefaultDomainGtoP.cpp + DefaultDomainQtoZ.cpp + Passes.cpp + Patterns.cpp + TorchOnnxToTorch.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchOnnxToTorch + + DEPENDS + TorchMLIRConversionTorchOnnxToTorchPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + TorchMLIRTorchDialect +) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp new file mode 100644 index 000000000000..44ced9eb4b64 --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -0,0 +1,365 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::onnx_c; + +// Simple rewrites for the default domain. +// See: https://onnx.ai/onnx/operators/ +// For operators that are effectively version invariant, we register with +// sinceVersion==1. We interpret this to include the following spec +// diffs that are irrelevant to this level of lowering: +// * Supported element types. +// * Limited broadcasting to full broadcasting support. +// +// There are a lot of spec revisions that basically generalized elementwise +// to be more normal and a direct translation vs a special case. This +// results in a lot of ONNX test cases that all reduce to the exact same +// thing here, so we simplify. +void mlir::torch::onnx_c::populateDefaultDomainAtoF( + OnnxCustomOpConversionPattern &patterns) { + patterns.onOp("Abs", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + // TODO: Acosh unimplemented in torch-mlir + // Add became forward compatible with Torch in version 7. + patterns.onOp("Add", 7, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + Value const1 = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs, const1); + return success(); + }); + // TODO: AffineGrid + patterns.onOp("And", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp( + "ArgMax", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + bool keepDims; + int64_t axis; + bool selectLastIndex; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType) || + binder.s64BoolAttr(keepDims, "keepdims", true) || + binder.s64IntegerAttr(axis, "axis", 0) || + binder.s64BoolAttr(selectLastIndex, "select_last_index", false)) + return failure(); + + if (selectLastIndex) { + // TODO: Figure out how to support this case. Need to add a reverse + // or something. + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: select_last_index=true"); + } + + // ONNX allows negative axis. + if (axis < 0) + axis += + cast(operand.getType()).getSizes().size(); + + Value constAxis = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); + Value constKeepDims = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getBoolAttr(keepDims)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, constAxis, constKeepDims); + return success(); + }); + patterns.onOp( + "ArgMin", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + bool keepDims; + int64_t axis; + bool selectLastIndex; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType) || + binder.s64BoolAttr(keepDims, "keepdims", true) || + binder.s64IntegerAttr(axis, "axis", 0) || + binder.s64BoolAttr(selectLastIndex, "select_last_index", false)) + return failure(); + + if (selectLastIndex) { + // TODO: Figure out how to support this case. Need to add a reverse + // or something. + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: select_last_index=true"); + } + + // ONNX allows negative axis. + if (axis < 0) + axis += + cast(operand.getType()).getSizes().size(); + + Value constAxis = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); + Value constKeepDims = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getBoolAttr(keepDims)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, constAxis, constKeepDims); + return success(); + }); + // TODO: Asin unimplemented in torch-mlir + // TODO: Asinh unimplemented in torch-mlir + // TODO: Atanh unimplemented in torch-mlir + patterns.onOp("Atan", 7, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp("Acos", 7, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp( + "BitShift", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + std::string direction; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType) || + binder.customOpNameStringAttr(direction, "direction", "")) + return failure(); + if (direction == "LEFT") { + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + } else { + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + } + return success(); + }); + patterns.onOp( + "BitwiseAnd", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + std::string direction; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp( + "BitwiseOr", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + std::string direction; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp("BitwiseNot", 18, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp( + "BitwiseXor", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + std::string direction; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp( + "Cast", 19, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + int64_t dtypeIntOnnx, dtypeIntTorch; + if (binder.tensorOperand(operand) || + binder.s64IntegerAttr(dtypeIntOnnx, "to") || + binder.tensorResultType(resultType)) + return failure(); + + // TODO: Add complete mapping. + switch (dtypeIntOnnx) { + case 1: + dtypeIntTorch = 6; // float + break; + case 10: + dtypeIntTorch = 5; // half + break; + case 11: + dtypeIntTorch = 7; // double + break; + case 16: + dtypeIntTorch = 15; // bfloat16 + break; + default: + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented support for the given dtype conversion"); + } + Value constDtype = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + dtypeIntTorch)); + Value none = rewriter.create(binder.getLoc()); + Value cstFalse = + rewriter.create(binder.getLoc(), false); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, constDtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); + return success(); + }); + patterns.onOp("Ceil", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp( + "Clip", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + if (binder.op->getNumOperands() == 1) { + Value source; + if (binder.tensorOperand(source) || + binder.tensorResultType(resultType)) + return failure(); + Value cstNone = + rewriter.create(binder.getLoc()); + rewriter.replaceOpWithNewOp( + binder.op, resultType, source, /*min=*/cstNone, /*max=*/cstNone); + return success(); + } else if (binder.op->getNumOperands() == 2) { + Value source, min; + if (binder.tensorOperands(source, min) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, source, /*min=*/min); + return success(); + } else if (binder.op->getNumOperands() == 3) { + Value source, min, max; + if (binder.tensorOperandAtIndex(source, 0) || + binder.tensorOperandAtIndex(min, 1) || + binder.tensorOperandAtIndex(max, 2) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, source, min, max); + return success(); + } + return failure(); + }); + patterns.onOp("Cos", 7, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp("Div", 14, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + std::string direction; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp("Equal", 19, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + std::string direction; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp("Floor", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); +} diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp new file mode 100644 index 000000000000..af4f06fdef77 --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -0,0 +1,29 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::onnx_c; + +// Simple rewrites for the default domain. +// See: https://onnx.ai/onnx/operators/ +// For operators that are effectively version invariant, we register with +// sinceVersion==1. We interpret this to include the following spec +// diffs that are irrelevant to this level of lowering: +// * Supported element types. +// * Limited broadcasting to full broadcasting support. +// +// There are a lot of spec revisions that basically generalized elementwise +// to be more normal and a direct translation vs a special case. This +// results in a lot of ONNX test cases that all reduce to the exact same +// thing here, so we simplify. +void mlir::torch::onnx_c::populateDefaultDomainGtoP( + OnnxCustomOpConversionPattern &patterns) {} diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp new file mode 100644 index 000000000000..23af89f329ab --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -0,0 +1,29 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::onnx_c; + +// Simple rewrites for the default domain. +// See: https://onnx.ai/onnx/operators/ +// For operators that are effectively version invariant, we register with +// sinceVersion==1. We interpret this to include the following spec +// diffs that are irrelevant to this level of lowering: +// * Supported element types. +// * Limited broadcasting to full broadcasting support. +// +// There are a lot of spec revisions that basically generalized elementwise +// to be more normal and a direct translation vs a special case. This +// results in a lot of ONNX test cases that all reduce to the exact same +// thing here, so we simplify. +void mlir::torch::onnx_c::populateDefaultDomainQtoZ( + OnnxCustomOpConversionPattern &patterns) {} diff --git a/lib/Conversion/TorchOnnxToTorch/PassDetail.h b/lib/Conversion/TorchOnnxToTorch/PassDetail.h new file mode 100644 index 000000000000..bbcd3413c59c --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/PassDetail.h @@ -0,0 +1,24 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSDETAIL_H +#define TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSDETAIL_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::torch::onnx_c { + +#define GEN_PASS_CLASSES +#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h.inc" + +} // namespace mlir::torch::onnx_c + +#endif // TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSDETAIL_H diff --git a/lib/Conversion/TorchOnnxToTorch/Passes.cpp b/lib/Conversion/TorchOnnxToTorch/Passes.cpp new file mode 100644 index 000000000000..1f8cb05fa02c --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/Passes.cpp @@ -0,0 +1,19 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h" + +namespace { +#define GEN_PASS_REGISTRATION +#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h.inc" +} // end namespace + +void mlir::torch::onnx_c::registerTorchOnnxToTorchPasses() { + ::registerPasses(); +} diff --git a/lib/Conversion/TorchOnnxToTorch/Patterns.cpp b/lib/Conversion/TorchOnnxToTorch/Patterns.cpp new file mode 100644 index 000000000000..6ca7824165d3 --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/Patterns.cpp @@ -0,0 +1,57 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using llvm::dbgs; +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::onnx_c; + +#define DEBUG_TYPE "torch-onnx" + +LogicalResult OnnxCustomOpConversionPattern::matchAndRewrite( + Torch::OperatorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto foundIt = namedHandlers.find(op.getNameAttr()); + if (foundIt == namedHandlers.end()) + return failure(); + auto ®gies = foundIt->second; + for (const HandlerReg ® : reggies) { + if (domainVersion < reg.sinceVersion) { + LLVM_DEBUG(dbgs() << ": skipping conversion " << foundIt->first + << ", sinceVersion=" << reg.sinceVersion + << ", for domainVersion=" << domainVersion << "\n"); + continue; + } + if (succeeded(reg.callback(OpBinder(op), rewriter))) { + return success(); + } else { + LLVM_DEBUG(dbgs() << ": conversion failed to apply: " << foundIt->first + << ", sinceVersion=" << reg.sinceVersion << "\n"); + } + } + return rewriter.notifyMatchFailure(op, "no matching versioned converter"); +} + +void OnnxCustomOpConversionPattern::populateLegalizedNames( + DenseSet &legalizedNames) { + for (auto it : namedHandlers) + legalizedNames.insert(it.first); +} + +void OnnxCustomOpConversionPattern::onOp(StringRef name, int64_t sinceVersion, + HandlerFn callback) { + SmallString<64> fullName(domainPrefix); + fullName.append(name); + StringAttr nameAttr = StringAttr::get(getContext(), fullName); + namedHandlers[nameAttr].push_back(HandlerReg(callback, sinceVersion)); +} diff --git a/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp b/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp new file mode 100644 index 000000000000..ea890bf0f4b6 --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp @@ -0,0 +1,87 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "./PassDetail.h" +#include "mlir/Support/LLVM.h" +#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h" +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using llvm::dbgs; +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::onnx_c; + +#define DEBUG_TYPE "torch-onnx" + +namespace { + +int64_t getDefaultOpsetVersion(Operation *containerOp) { + auto attr = + containerOp->getAttrOfType("torch.onnx_meta.opset_version"); + if (!attr) + return 0; + if (auto type = dyn_cast(attr.getType())) { + if (!type || !type.isSigned()) + return 0; + } + return attr.getSInt(); +} + +class ConvertTorchOnnxToTorch + : public ConvertTorchOnnxToTorchBase { +public: + ConvertTorchOnnxToTorch() = default; + void runOnOperation() override { + MLIRContext *context = &getContext(); + + // Populate our patterns for each handled domain. + int64_t defaultOpsetVersion = getDefaultOpsetVersion(getOperation()); + if (defaultOpsetVersion == 0) { + emitError(getOperation().getLoc()) + << "function is missing onnx opset version attribute " + "(torch.onnx_meta.opset_version)"; + return signalPassFailure(); + } + + auto defaultDomainPatterns = + std::make_unique( + context, "onnx.", + /*domainVersion=*/defaultOpsetVersion); + populateDefaultDomainAtoF(*defaultDomainPatterns); + populateDefaultDomainGtoP(*defaultDomainPatterns); + populateDefaultDomainQtoZ(*defaultDomainPatterns); + + // Ask each domain for its handled names and configure the + // conversion target. + ConversionTarget target(*context); + DenseSet legalizedNames; + defaultDomainPatterns->populateLegalizedNames(legalizedNames); + target.addLegalDialect(); + target.addDynamicallyLegalOp([&](Torch::OperatorOp op) { + return !legalizedNames.contains(op.getNameAttr()); + }); + + RewritePatternSet patterns(context); + patterns.insert(std::move(defaultDomainPatterns)); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> +mlir::torch::onnx_c::createTorchOnnxToTorchPass() { + return std::make_unique(); +} diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 26c7e1d01cf1..4eb02215a8bf 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -15,13 +15,13 @@ #include "../PassDetail.h" #include "PopulatePatterns.h" -#include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" +#include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -362,7 +362,7 @@ class ConvertAtenViewOp : public OpConversionPattern { auto [inputShape, outputShape] = getInputAndOutputShape(op.getSelf(), outputSizeTorchInt); - + // Currently, we only handle the cases where each dimension is either // being expanded or collapsed. We do not handle cases where it's neither // collapsing nor expanding like view of [2,3] for 3x2 tensor. @@ -380,8 +380,8 @@ class ConvertAtenViewOp : public OpConversionPattern { bool inputHasOneDynDim = llvm::count(inputShape, kUnknownSize) == 1; bool outputHasOneDynDim = llvm::count(outputShape, kUnknownSize) == 1; bool singleDynDimsAreEqual = - inputHasOneDynDim && outputHasOneDynDim && - productReduce(inputShape) == productReduce(outputShape); + inputHasOneDynDim && outputHasOneDynDim && + productReduce(inputShape) == productReduce(outputShape); SmallVector> unchangedDims; for (auto [outputDim, outputDimSize] : llvm::enumerate(outputSizeTorchInt)) { @@ -533,6 +533,10 @@ class ConvertAtenViewOp : public OpConversionPattern { } } + auto cast = [&](Location loc, Type t, Value v) -> Value { + return rewriter.createOrFold(loc, t, v); + }; + // Check if the shapes already match up to dynamic sizes. If so, we can just // cast as the result type because the previous loop sets up the necessary // dim checks in case of dynamic sizes. @@ -542,7 +546,9 @@ class ConvertAtenViewOp : public OpConversionPattern { llvm::all_of(outputAssociations, [](ReassociationIndices indices) { return indices.size() == 1; })) { - rewriter.replaceOpWithNewOp(op, resultType, input); + + auto castResult = cast(loc, resultType, input); + rewriter.replaceOp(op, castResult); return success(); } @@ -551,8 +557,7 @@ class ConvertAtenViewOp : public OpConversionPattern { makeShapeLLVMCompatible(outputShape), resultType.getElementType()); Type adjustedInputType = RankedTensorType::get( makeShapeLLVMCompatible(inputShape), resultType.getElementType()); - Value castedInput = - rewriter.create(loc, adjustedInputType, input); + Value castedInput = cast(loc, adjustedInputType, input); std::optional expandedInput; std::optional collapsedInput; @@ -602,7 +607,8 @@ class ConvertAtenViewOp : public OpConversionPattern { Value result = collapsedInput.has_value() ? collapsedInput.value() : expandedInput.value(); - rewriter.replaceOpWithNewOp(op, resultType, result); + auto castResult = cast(loc, resultType, result); + rewriter.replaceOp(op, castResult); return success(); } @@ -1154,7 +1160,8 @@ class ConvertAtenContiguousOp : public OpConversionPattern { return failure(); Type resultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resultType, adaptor.getSelf()); + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getSelf()); return success(); } }; @@ -1407,7 +1414,8 @@ class ConvertAtenViewAsRealOp : public OpConversionPattern { SmallVector indexingMaps{inputMap, outputMap}; - SmallVector iteratorTypes(resultType.getRank(), utils::IteratorType::parallel); + SmallVector iteratorTypes( + resultType.getRank(), utils::IteratorType::parallel); Value constantZero = getConstant(rewriter, loc, 0, mlir::IndexType::get(context)); @@ -1417,7 +1425,6 @@ class ConvertAtenViewAsRealOp : public OpConversionPattern { loc, outTensor.getType(), input, outTensor, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { - Value realVal = b.create(loc, elementType, args[0]); Value imagVal = diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index 0e89d822669f..277341bea874 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -11,13 +11,13 @@ #include "../PassDetail.h" #include "PopulatePatterns.h" -#include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" +#include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index bbf53162d6a1..b263786c3dbb 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -11,12 +11,12 @@ #include "../PassDetail.h" #include "PopulatePatterns.h" -#include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" +#include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -51,12 +51,24 @@ class ConvertAtenMmOp : public OpConversionPattern { // The compiler cannot crash even if the user wrote an erroneous program! if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); - if (lhs.getType().cast().getRank() != 2 || - rhs.getType().cast().getRank() != 2) { + + RankedTensorType lhsType = lhs.getType().cast(); + RankedTensorType rhsType = rhs.getType().cast(); + + if (lhsType.getRank() != 2 || rhsType.getRank() != 2) { return rewriter.notifyMatchFailure( op, "expected both operands to aten.mm to be rank 2"); } + ValueTensorType lhsTorchType = + op.getSelf().getType().cast(); + ValueTensorType rhsTorchType = + op.getMat2().getType().cast(); + if (lhsTorchType.getDtype() != rhsTorchType.getDtype()) { + return rewriter.notifyMatchFailure( + op, "unsupported: aten.mm with different input element types"); + } + Value lhsDim0 = rewriter.create(loc, lhs, 0); Value rhsDim1 = rewriter.create(loc, rhs, 1); @@ -73,16 +85,22 @@ class ConvertAtenMmOp : public OpConversionPattern { Type newResultType = getTypeConverter()->convertType(op.getType()); Type elementType = newResultType.cast().getElementType(); - Value initTensor = rewriter.create( - loc, ArrayRef{lhsDim0, rhsDim1}, elementType); - Value c0 = rewriter.create( - loc, FloatAttr::get(elementType, 0.0)); - Value zeroFill = - rewriter.create(loc, c0, initTensor).getResult(0); - Value matmul = rewriter - .create(loc, zeroFill.getType(), - ValueRange{lhs, rhs}, zeroFill) - .getResult(0); + Value zeroFill = createZeroInitTensor( + rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType); + + Value matmul; + auto intType = dyn_cast(lhsTorchType.getDtype()); + if (intType && intType.isUnsigned()) { + matmul = rewriter + .create( + loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill) + .getResult(0); + } else { + matmul = rewriter + .create(loc, zeroFill.getType(), + ValueRange{lhs, rhs}, zeroFill) + .getResult(0); + } // When constructed with just dynamic sizes, EmptyOp will have a result // type which has all `?`'s for dimensions, which might not be the result // type of `op`. The constraints on later linalg ops means that the result @@ -830,6 +848,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { indices); }; + // expand F,C,H,W -> G,F/G,C,H,W auto expandWeight = [&](Value tensor) { auto inType = tensor.getType().cast(); auto inShape = makeShapeTorchCompatible(inType.getShape()); @@ -850,21 +869,19 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { Value paddedInputExpanded = expandGroups(paddedInput, 1); Value weightExpanded = expandWeight(weight); - Value outputTensorExpanded = expandGroups(outputTensor, 1); + auto expandOutputTensor = expandGroups(outputTensor, 1); // TODO: add 1D and 3D case conv = rewriter - .create( - loc, outputTensorExpanded.getType(), + .create( + loc, expandOutputTensor.getResultType(), ValueRange{paddedInputExpanded, weightExpanded}, - outputTensorExpanded, stridesAttr, dilationAttr) + expandOutputTensor.getResult(), stridesAttr, dilationAttr) .getResult(0); - SmallVector indices{{0}, {1, 2}}; - for (auto dim = 3; dim <= (int64_t)inRank; dim++) - indices.push_back({dim}); conv = rewriter.create( - loc, outputTensor.getType(), conv, indices); + loc, outputTensor.getType(), conv, + expandOutputTensor.getReassociationIndices()); } Type newResultType = getTypeConverter()->convertType(op.getType()); diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 1d7ff925b6ed..87419f0935ab 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -11,12 +11,12 @@ #include "../PassDetail.h" #include "PopulatePatterns.h" -#include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" +#include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" diff --git a/lib/Conversion/TorchToLinalg/Random.cpp b/lib/Conversion/TorchToLinalg/Random.cpp index e1a3e416c460..26a2c0ea551a 100644 --- a/lib/Conversion/TorchToLinalg/Random.cpp +++ b/lib/Conversion/TorchToLinalg/Random.cpp @@ -11,12 +11,12 @@ #include "../PassDetail.h" #include "PopulatePatterns.h" -#include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" +#include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index 641f1ef8cc1c..289851cd3d27 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -11,13 +11,13 @@ #include "../PassDetail.h" #include "PopulatePatterns.h" -#include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" +#include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -30,70 +30,80 @@ using namespace mlir::torch; using namespace mlir::torch::Torch; namespace { -// Aten maxdim lowering represents the MaxDim op as an linalg.indexed_generic -// op, producing two output buffers. +// Aten max.dim (min.dim) lowering represents the MaxDimOp (MinDimOp) as an +// linalg.indexed_generic op, producing two output buffers. // -// The first output buffer contains the maximum value found. It is initialized -// to the minimum representable value of the input element type. +// The first output buffer contains the maximum (minium) value found. It is +// initialized to the minimum (maximum) representable value of the input +// element type. // -// The second output buffer contains the index of the found maximum value. It is -// initialized to 0 and is resulting integer type. +// The second output buffer contains the index of the found maximum (minimum) +// value. It is initialized to 0 and is resulting integer type. // -// The indexed_generic op updates both the maximum value and index if the -// current value exceeds the running max. -class ConvertAtenMaxDimOp : public OpConversionPattern { +// The indexed_generic op updates both the maximum (minimum) value and index +// if the current value exceeds the running max (min). +template +class ConvertAtenMinMaxDimOp : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::getTypeConverter; + + using OpAdaptor = typename OpTy::Adaptor; LogicalResult - matchAndRewrite(AtenMaxDimOp maxDimOp, OpAdaptor adaptor, + matchAndRewrite(OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + static_assert(std::is_same() || + std::is_same()); + constexpr bool isMax = std::is_same(); + const llvm::StringRef opName = op->getName().getStringRef(); - Location loc = maxDimOp.getLoc(); + Location loc = op.getLoc(); Value input = adaptor.getSelf(); RankedTensorType valResultType = getTypeConverter() - ->convertType(maxDimOp.getResult(0).getType()) - .cast(); + ->convertType(op.getResult(0).getType()) + .template cast(); + RankedTensorType idxResultType = - getTypeConverter() - ->convertType(maxDimOp.getResult(1).getType()) - .cast(); - RankedTensorType inputType = input.getType().cast(); + this->getTypeConverter() + ->convertType(op.getResult(1).getType()) + .template cast(); + RankedTensorType inputType = + input.getType().template cast(); Type idxElementType = idxResultType.getElementType(); if (!idxElementType.isa()) return rewriter.notifyMatchFailure( - maxDimOp, - "aten.max_dim to linalg.* requires integer-like result type"); + op, opName + " to linalg.* requires integer-like result type"); bool keepDim = false; - if (!matchPattern(maxDimOp.getKeepdim(), m_TorchConstantBool(&keepDim))) + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) return rewriter.notifyMatchFailure( - maxDimOp, "aten.max_dim requires boolean value for keepdim"); + op, opName + " requires boolean value for keepdim"); int64_t dim; - if (!matchPattern(maxDimOp.getDim(), m_TorchConstantInt(&dim))) + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure( - maxDimOp, "aten.max_dim to linalg.* requires int value for Dim"); + op, opName + " to linalg.* requires int value for Dim"); dim = toPositiveDim(dim, inputType.getRank()); if (!isValidDim(dim, inputType.getRank())) - return rewriter.notifyMatchFailure(maxDimOp, "dim is not a valid dim"); + return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); Type inElementType = inputType.getElementType(); if (!inElementType.isa()) { if (inElementType.isa()) { - auto integerTy = maxDimOp.getSelf() + auto integerTy = op.getSelf() .getType() - .cast() + .template cast() .getDtype() - .dyn_cast(); + .template dyn_cast(); if (integerTy.isUnsigned()) return rewriter.notifyMatchFailure( - maxDimOp, "aten.max_dim to linalg.* requires input element type " + op, opName + " to linalg.* requires input element type " "to be signed in case of integer"); } else { return rewriter.notifyMatchFailure( - maxDimOp, "aten.max_dim to linalg.* requires Float or Integer " + op, opName + " to linalg.* requires Float or Integer " "input element type"); } } @@ -112,29 +122,29 @@ class ConvertAtenMaxDimOp : public OpConversionPattern { Value filledTensorIdx = createZeroInitTensor(rewriter, loc, resultShape, idxElementType); - // Second fill the output buffer for the running max. - Value initTensorMax = rewriter.create( + // Second fill the output buffer for the running max or min. + Value initTensorVal = rewriter.create( loc, getAsOpFoldResult(resultShape), inElementType); - Value fillValueMax; + Value fillValue; if (inElementType.isa()) { - fillValueMax = rewriter.create( + fillValue = rewriter.create( loc, rewriter.getFloatAttr( inElementType, APFloat::getInf( inElementType.cast().getFloatSemantics(), - /*Negative=*/true))); + /*Negative=*/isMax))); } else { - fillValueMax = rewriter.create( - loc, rewriter.getIntegerAttr( - inElementType, - APSInt::getSignedMinValue( - inElementType.cast().getWidth()))); + auto width = inElementType.cast().getWidth(); + auto init = isMax ? APSInt::getSignedMinValue(width) + : APSInt::getSignedMaxValue(width); + fillValue = rewriter.create( + loc, rewriter.getIntegerAttr(inElementType, init)); } - Value filledTensorMax = - rewriter.create(loc, fillValueMax, initTensorMax) + Value filledTensorVal = + rewriter.create(loc, fillValue, initTensorVal) .result(); // Create the affine expressions that will be used to @@ -161,8 +171,8 @@ class ConvertAtenMaxDimOp : public OpConversionPattern { auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs}); auto linalgOp = rewriter.create( loc, - ArrayRef({filledTensorMax.getType(), filledTensorIdx.getType()}), - input, ValueRange({filledTensorMax, filledTensorIdx}), maps, + ArrayRef({filledTensorVal.getType(), filledTensorIdx.getType()}), + input, ValueRange({filledTensorVal, filledTensorIdx}), maps, iteratorTypes, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { @@ -174,33 +184,51 @@ class ConvertAtenMaxDimOp : public OpConversionPattern { nestedLoc, oldIndex.getType(), rewriter.create(loc, dim)); - Value resultMax, predicate; + Value resultVal, predicate; if (inElementType.isa()) { - resultMax = rewriter.create(nestedLoc, newValue, - oldValue); - predicate = rewriter.create( - nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); + arith::CmpFPredicate predType; + if (isMax) { + predType = arith::CmpFPredicate::OGT; + resultVal = rewriter.create( + nestedLoc, newValue, oldValue); + } else { + predType = arith::CmpFPredicate::OLT; + resultVal = rewriter.create( + nestedLoc, newValue, oldValue); + } + + predicate = rewriter.create(nestedLoc, predType, + newValue, oldValue); } else { - resultMax = - rewriter.create(nestedLoc, newValue, oldValue); - predicate = rewriter.create( - nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue); + arith::CmpIPredicate predType; + if (isMax) { + predType = arith::CmpIPredicate::sgt; + resultVal = rewriter.create(nestedLoc, newValue, + oldValue); + } else { + predType = arith::CmpIPredicate::slt; + resultVal = rewriter.create(nestedLoc, newValue, + oldValue); + } + predicate = rewriter.create(nestedLoc, predType, + newValue, oldValue); } auto resultIndex = rewriter.create( nestedLoc, predicate, newIndex, oldIndex); nestedBuilder.create( - nestedLoc, ValueRange({resultMax, resultIndex})); + nestedLoc, ValueRange({resultVal, resultIndex})); }); // This cast is required to fix the shape in the case of keepDim=True - Value maxValuesCast = rewriter.create( + Value valuesCast = rewriter.create( loc, valResultType, linalgOp.getResult(0)); - Value maxIdxCast = rewriter.create(loc, idxResultType, - linalgOp.getResult(1)); - rewriter.replaceOp(maxDimOp, {maxValuesCast, maxIdxCast}); + Value idxCast = rewriter.create(loc, idxResultType, + linalgOp.getResult(1)); + rewriter.replaceOp(op, {valuesCast, idxCast}); return success(); } }; + } // namespace static Value createInitElementForReduceOp(OpBuilder &b, Location loc, @@ -574,7 +602,9 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality( ConversionTarget &target) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); - patterns.add(typeConverter, context); + patterns.add>(typeConverter, context); + target.addIllegalOp(); + patterns.add>(typeConverter, context); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp index 7e73fabd8e9f..434b50b034dd 100644 --- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp +++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp @@ -11,13 +11,13 @@ #include "../PassDetail.h" #include "PopulatePatterns.h" -#include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" +#include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -127,9 +127,9 @@ class ConvertConstantTensorAllocOp : public OpConversionPattern { if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) return rewriter.notifyMatchFailure( op, "unimplemented: dtype must be a constant integer or none"); - FailureOr maybeResultElementType = getTypeForScalarType( - op->getContext(), (torch_upstream::ScalarType)dtypeInt, - IntegerType::Signless); + FailureOr maybeResultElementType = + torch_to_linalg::getBackendTypeForScalarType( + op->getContext(), (torch_upstream::ScalarType)dtypeInt); if (failed(maybeResultElementType)) { return rewriter.notifyMatchFailure( op, "unable to convert `dtypeInt` to builtin type"); @@ -233,9 +233,9 @@ class ConvertAtenEmptyMemoryFormatOp if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) return rewriter.notifyMatchFailure( op, "unimplemented: dtype must be a constant integer or none"); - FailureOr maybeResultElementType = getTypeForScalarType( - op->getContext(), (torch_upstream::ScalarType)dtypeInt, - IntegerType::Signless); + FailureOr maybeResultElementType = + torch_to_linalg::getBackendTypeForScalarType( + op->getContext(), (torch_upstream::ScalarType)dtypeInt); if (failed(maybeResultElementType)) { return rewriter.notifyMatchFailure( op, "unable to convert `dtypeInt` to builtin type"); diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index ee968daff010..2cc37a88313a 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -11,7 +11,6 @@ #include "../PassDetail.h" #include "PopulatePatterns.h" -#include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" @@ -19,6 +18,7 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" +#include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -57,7 +57,7 @@ static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type, static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { - return createComparisonTemplate( b, loc, elementalType, lhs, rhs); @@ -66,7 +66,7 @@ static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType, static Value createGreaterThanOrEqual(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { - return createComparisonTemplate( b, loc, elementalType, lhs, rhs); @@ -74,7 +74,7 @@ static Value createGreaterThanOrEqual(OpBuilder &b, Location loc, static Value createLessThan(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { - return createComparisonTemplate( b, loc, elementalType, lhs, rhs); @@ -82,7 +82,7 @@ static Value createLessThan(OpBuilder &b, Location loc, Type elementalType, static Value createLessThanOrEqual(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { - return createComparisonTemplate( b, loc, elementalType, lhs, rhs); @@ -272,6 +272,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } + if (isa(op)) { + return createCalculationForMathOpWithDtypeConversion( + b, converter, payloadArgs[0], op); + } if (auto clone = dyn_cast(op)) { int64_t memoryFormat; if (!clone.getMemoryFormat().getType().isa() && @@ -366,6 +370,20 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } + if (auto bitwiseLeftShiftTensor = + dyn_cast(op)) { + Type dtype = converter->convertType(bitwiseLeftShiftTensor.getType()) + .cast() + .getElementType(); + if (!dtype.isa()) { + bitwiseLeftShiftTensor.emitError( + "Bitwise_Left_Shift op does not support non-integer input dtype."); + return nullptr; + } + Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); + Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); + return b.create(loc, lhs, rhs); + } if (isa(op)) { MLIRContext *context = op->getContext(); Type floatDtype = mlir::FloatType::getF64(context); @@ -1001,6 +1019,58 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } return result; } + if (auto clampTensor = dyn_cast(op)) { + AtenClampTensorOp::Adaptor adaptor(operands); + auto min = adaptor.getMin(); + auto max = adaptor.getMax(); + if (min.getType().isa() || + max.getType().isa()) { + clampTensor.emitError("unimplemented: runtime optional type"); + return nullptr; + } + Type dtype = converter->convertType(clampTensor.getType()) + .cast() + .getElementType(); + bool isMinNone = true; + auto result = payloadArgs[0]; + if (!min.getType().isa()) { + isMinNone = false; + auto minPromoted = convertScalarToDtype(b, loc, payloadArgs[1], dtype); + Value pred; + if (dtype.isa()) { + pred = b.create(loc, arith::CmpFPredicate::ULT, result, + minPromoted); + } else if (dtype.isa()) { + pred = b.create(loc, arith::CmpIPredicate::slt, result, + minPromoted); + } else { + clampTensor.emitError( + "unimplemented: dtype other than float and integer " + "types are not supported."); + return nullptr; + } + result = b.create(loc, pred, minPromoted, result); + } + if (!max.getType().isa()) { + max = isMinNone ? payloadArgs[1] : payloadArgs[2]; + auto maxPromoted = convertScalarToDtype(b, loc, max, dtype); + Value pred; + if (dtype.isa()) { + pred = b.create(loc, arith::CmpFPredicate::UGT, result, + maxPromoted); + } else if (dtype.isa()) { + pred = b.create(loc, arith::CmpIPredicate::sgt, result, + maxPromoted); + } else { + clampTensor.emitError( + "unimplemented: dtype other than float and integer " + "types are not supported."); + return nullptr; + } + result = b.create(loc, pred, maxPromoted, result); + } + return result; + } if (auto rsub = dyn_cast(op)) { Type dtype = converter->convertType(rsub.getType()) .cast() @@ -1043,9 +1113,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( atenToDtype.emitError("unimplemented: dtype must be a constant integer"); return nullptr; } - FailureOr maybeResultElementType = getTypeForScalarType( - atenToDtype->getContext(), (torch_upstream::ScalarType)dtypeInt, - IntegerType::Signless); + FailureOr maybeResultElementType = + torch_to_linalg::getBackendTypeForScalarType( + atenToDtype->getContext(), (torch_upstream::ScalarType)dtypeInt); if (failed(maybeResultElementType)) { atenToDtype.emitError("unable to convert `dtypeInt` to builtin type"); return nullptr; @@ -1246,22 +1316,24 @@ class ConvertElementwiseOp : public ConversionPattern { AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp, AtenAtan2Op, AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, - AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp, - AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp, - AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, - AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp, - AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, + AtenClampTensorOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, + AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, + AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, + AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, + AtenRemainderScalarOp, AtenAbsOp, AtenReciprocalOp, + AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, - AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp, - AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, - AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, - AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, - AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp, - AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, - AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, - AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, - AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, - AtenFillTensorOp, AtenAtanOp, AtenRealOp, AtenImagOp>(op)) + AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp, + AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, + AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, + AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, + AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp, + AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, + AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp, + AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, + AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp, + AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, + AtenAtanOp, AtenAcosOp, AtenRealOp, AtenImagOp>(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) @@ -1759,6 +1831,51 @@ class ConvertAtenDetachOp : public OpConversionPattern { }; } // namespace +namespace { +class ConvertPrimsSplitDimOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(PrimsSplitDimOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + auto aRankedTensorType = adaptor.getA().getType().cast(); + + const TypeConverter *typeConverter = getTypeConverter(); + + auto resultRankedTensorType = + typeConverter->convertType(op.getType()).cast(); + + // The dimension being split must be statically known. + + int64_t dimInt; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt))) + return failure(); + + SmallVector associations; + associations.reserve(aRankedTensorType.getRank()); + + for (unsigned i = 0; i < dimInt; ++i) { + associations.push_back(ReassociationIndices{i}); + } + associations.push_back(ReassociationIndices{dimInt, dimInt + 1}); + for (int i = dimInt + 2; i < resultRankedTensorType.getRank(); ++i) { + associations.push_back(ReassociationIndices{i}); + } + + auto expanded = rewriter.createOrFold( + op.getLoc(), resultRankedTensorType, adaptor.getA(), associations); + + rewriter.replaceOpWithNewOp(op, resultRankedTensorType, + expanded); + return success(); + } +}; +} // namespace + namespace { class ConvertPrimsCollapseOp : public OpConversionPattern { public: @@ -1850,21 +1967,22 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenAtan2Op, - AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, - AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp, - AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, - AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, - AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, - AtenBitwiseXorTensorOp, AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, - AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, - AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, - AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, - AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, - AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, - AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp, AtenLogicalNotOp, - AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp, AtenBitwiseNotOp, - AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenRealOp, - AtenImagOp>(); + AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenClampTensorOp, + AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, + AtenCeilOp, AtenPreluOp, AtenPowScalarOp, AtenPowTensorScalarOp, + AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, + AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, + AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, + AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp, + AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, + AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, + AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, + AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp, + AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, + AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp, + AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp, + AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, + AtenFillTensorOp, AtenRealOp, AtenImagOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); @@ -1872,10 +1990,10 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); patterns.add(typeConverter, context); - + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); patterns.add(typeConverter, context); diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index a666ca30b02f..ccc78985dc6c 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -7,8 +7,6 @@ // //===----------------------------------------------------------------------===// -#include "Utils.h" - #include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -17,10 +15,10 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/IR/Matchers.h" +#include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" -#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; @@ -546,3 +544,18 @@ Value torch_to_linalg::convertTensorToElementType(OpBuilder &b, Location loc, return torch_to_linalg::createElementwiseLinalgGeneric( b, loc, {tensor}, elementType, dtypePromoteBody); } + +FailureOr torch_to_linalg::getBackendTypeForScalarType( + MLIRContext *context, torch_upstream::ScalarType dtypeInt) { + FailureOr maybeType = + getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt); + if (failed(maybeType)) { + return failure(); + } + Type type = *maybeType; + // The linalg-on-tensors backend currently expects integers to be signless. + if (auto intType = type.dyn_cast()) { + type = IntegerType::get(context, intType.getWidth(), IntegerType::Signless); + } + return type; +} diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index a2a7cdab9da2..f0dc4aaf2dfa 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -1664,13 +1664,18 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "unimplemented: dtype must be a constant integer or none"); FailureOr maybeResultElementType = getTypeForScalarType( - op->getContext(), (torch_upstream::ScalarType)dtypeInt, - IntegerType::Signless); + op->getContext(), (torch_upstream::ScalarType)dtypeInt); if (failed(maybeResultElementType)) { return rewriter.notifyMatchFailure( op, "unable to convert `dtypeInt` to builtin type"); } resultElementType = *maybeResultElementType; + // The stablehlo backend expects signed integers to be signless. + if (resultElementType.isSignedInteger()) { + resultElementType = IntegerType::get( + op->getContext(), resultElementType.getIntOrFloatBitWidth(), + IntegerType::Signless); + } } // Create an uninitialized tensor of `resultSize` shape. diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 5ddab7320c7a..4041a99d949b 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -5601,8 +5601,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "unimplemented: dtype must be a constant integer or none"); FailureOr maybeResultElementType = getTypeForScalarType( - ctx, (torch_upstream::ScalarType)dtypeInt, - IntegerType::Signless); + ctx, (torch_upstream::ScalarType)dtypeInt); if (failed(maybeResultElementType)) { return rewriter.notifyMatchFailure( op, "unable to convert `dtypeInt` to builtin type"); diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index c235f2694a79..e6b29ca98060 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -162,6 +162,42 @@ static Value getScalarIntValue(Value input, Location loc, return nullptr; } +static Value getScalarFloatValue(Value input, Location loc, + PatternRewriter &rewriter) { + auto inputType = input.getType(); + if (inputType.isa()) { + return input; + } + + auto inputTensorType = inputType.dyn_cast(); + if (!inputTensorType) + return nullptr; + + Type inputDtype = inputTensorType.getOptionalDtype(); + if (!inputDtype || + (!inputDtype.isF16() && !inputDtype.isF32() && !inputDtype.isF64())) + return nullptr; + + std::optional inputRank = getTensorRank(input); + if (!inputRank || *inputRank != 0) + return nullptr; + + if (auto valueTensorLiteralOp = input.getDefiningOp()) { + auto val = valueTensorLiteralOp.getValue() + .cast() + .getSplatValue() + .getValueAsDouble(); + return rewriter.create( + loc, rewriter.getF64FloatAttr(val)); + } else if (auto primNumToTensorScalarOp = + input.getDefiningOp()) { + return primNumToTensorScalarOp.getA(); + } else if (auto tensorFloatOp = input.getDefiningOp()) { + return tensorFloatOp.getT(); + } + return nullptr; +} + //===----------------------------------------------------------------------===// // MethodOp //===----------------------------------------------------------------------===// @@ -1604,6 +1640,27 @@ OpFoldResult AtenIntBoolOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// AtenMaskedFillTensorOp +//===----------------------------------------------------------------------===// + +// Fold 0d fill tensor to scalar +void AtenMaskedFillTensorOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(+[](AtenMaskedFillTensorOp op, PatternRewriter &rewriter) { + auto scalarIntVal = + getScalarIntValue(op.getValue(), op->getLoc(), rewriter); + auto scalarFloatVal = + getScalarFloatValue(op.getValue(), op->getLoc(), rewriter); + if (!scalarIntVal && !scalarFloatVal) + return failure(); + Value scalarVal = scalarIntVal ? scalarIntVal : scalarFloatVal; + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), op.getMask(), scalarVal); + return failure(); + }); +} + //===----------------------------------------------------------------------===// // AtenSortIntOp //===----------------------------------------------------------------------===// @@ -2402,17 +2459,6 @@ OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) { return list.getElements()[0]; } -//===----------------------------------------------------------------------===// -// AtenStackOp -//===----------------------------------------------------------------------===// - -OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) { - auto list = getOperand(0).getDefiningOp(); - if (!list || !list->hasOneUse() || list.getElements().size() != 1) - return nullptr; - return list.getElements()[0]; -} - //===----------------------------------------------------------------------===// // AtenBroadcastToOp //===----------------------------------------------------------------------===// @@ -2688,6 +2734,12 @@ void AtenBroadcastToOp::getCanonicalizationPatterns(RewritePatternSet &patterns, } } + if (selfShape.empty()) { + // Don't create view ops with input rank 0 because those are not supported + // in the linalg lowering. + return rewriter.notifyMatchFailure(op, "unimplemented: input rank 0 is not supported"); + } + // Create 1, ..., 1, inputShape[0], inputShape[1], inputShape[2] SmallVector reshapeShape = resultShape; for (unsigned i = 0; i < selfShape.size(); i++) diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index cee9705af24a..cf832b1b755e 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -200,7 +200,10 @@ static bool isValidTorchDtype(Type dtype) { } } if (type.isUnsigned()) { - return type.getWidth() == 8 || type.getWidth() == 4; + for (unsigned width : {4, 8, 16, 32, 64}) { + if (type.getWidth() == width) + return true; + } } } return false; diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 4a4197b6837d..a02465399a9c 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6227,7 +6227,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " return %3 : !torch.list\n" " }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.testing_framework._convert_dtype_to_int(%arg0: !torch.int) -> !torch.int {\n" +" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.testing_framework._convert_dtype_to_int(%arg0: !torch.int) -> !torch.int {\n" " return %arg0 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.triu\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" @@ -6453,10 +6453,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.clamp.Tensor\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.clamp_min\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.clamp_min.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.clamp_max\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6543,6 +6551,69 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " } : (!torch.int, !torch.bool) -> ()\n" " return %7 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.prims.split_dim\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %str = torch.constant.str \"AssertionError: 'outer_length' must divide the size of the dimension, a[dim]\"\n" +" %str_0 = torch.constant.str \"AssertionError: 'outer_length' must be positive\"\n" +" %str_1 = torch.constant.str \"AssertionError: 'dim' must be less than the rank of the tensor\"\n" +" %none = torch.constant.none\n" +" %str_2 = torch.constant.str \"AssertionError: 'dim' must be non-negative\"\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.ge.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %2 = torch.aten.lt.int %arg1, %1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.gt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.aten.remainder.int %4, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" %6 = torch.aten.eq.int %5, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.Loop %arg1, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %15 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.aten.append.t %7, %15 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %8 = torch.aten.append.t %7, %arg2 : !torch.list, !torch.int -> !torch.list\n" +" %9 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int\n" +" %10 = torch.aten.floordiv.int %9, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" %11 = torch.aten.append.t %7, %10 : !torch.list, !torch.int -> !torch.list\n" +" %12 = torch.aten.add.int %arg1, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %13 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %14 = torch.aten.__range_length %12, %13, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %14, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %15 = torch.aten.__derive_index %arg3, %12, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %16 = torch.aten.__getitem__.t %arg0, %15 : !torch.list, !torch.int -> !torch.int\n" +" %17 = torch.aten.append.t %7, %16 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %7 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.to.dtype\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6791,6 +6862,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.argmin\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.one_hot\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: getting num_classes from tensor contents is not supported\"\n" @@ -6817,6 +6892,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %2 = torch.prim.TupleConstruct %1, %1 : !torch.list, !torch.list -> !torch.tuple, list>\n" " return %2 : !torch.tuple, list>\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.min.dim\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple, list> {\n" +" %0 = torch.derefine %arg1 : !torch.int to !torch.optional\n" +" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %2 : !torch.tuple, list>\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.amax\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.list {\n" " %none = torch.constant.none\n" " %0 = torch.derefine %arg1 : !torch.list to !torch.optional>\n" @@ -7668,6 +7749,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.bitwise_left_shift.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.bitwise_not\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7975,7 +8060,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" @@ -7990,12 +8075,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %3 : !torch.int\n" " }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg0: !torch.int) -> !torch.bool {\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_complex_dtypes() : () -> !torch.list\n" +" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_complex_dtypes() : () -> !torch.list\n" " %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" " return %1 : !torch.bool\n" " }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_complex_dtypes() -> !torch.list {\n" +" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_complex_dtypes() -> !torch.list {\n" " %int10 = torch.constant.int 10\n" " %int9 = torch.constant.int 9\n" " %0 = torch.prim.ListConstruct %int9, %int10 : (!torch.int, !torch.int) -> !torch.list\n" @@ -8466,6 +8551,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" " return %4 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.prims.split_dim\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" @@ -8475,7 +8564,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %true = torch.constant.bool true\n" " %false = torch.constant.bool false\n" " %int6 = torch.constant.int 6\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%arg0) : (!torch.int) -> !torch.bool\n" +" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%arg0) : (!torch.int) -> !torch.bool\n" " %1 = torch.prim.If %0 -> (!torch.bool) {\n" " %4 = torch.aten.ne.int %arg0, %int6 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If.yield %4 : !torch.bool\n" @@ -8485,7 +8574,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %2 = torch.prim.If %1 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %4 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg0) : (!torch.int) -> !torch.bool\n" +" %4 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%arg0) : (!torch.int) -> !torch.bool\n" " torch.prim.If.yield %4 : !torch.bool\n" " }\n" " %3 = torch.prim.If %2 -> (!torch.int) {\n" @@ -8495,12 +8584,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %3 : !torch.int\n" " }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%arg0: !torch.int) -> !torch.bool {\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_float_dtypes() : () -> !torch.list\n" +" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_float_dtypes() : () -> !torch.list\n" " %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" " return %1 : !torch.bool\n" " }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_float_dtypes() -> !torch.list {\n" +" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_float_dtypes() -> !torch.list {\n" " %int7 = torch.constant.int 7\n" " %int6 = torch.constant.int 6\n" " %int15 = torch.constant.int 15\n" @@ -8585,7 +8674,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %2 = torch.prim.If %1 -> (!torch.int) {\n" " torch.prim.If.yield %0#1 : !torch.int\n" " } else {\n" @@ -8594,12 +8683,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list\n" +" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list\n" " %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" " return %1 : !torch.bool\n" " }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_integer_dtypes() -> !torch.list {\n" +" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list {\n" " %int4 = torch.constant.int 4\n" " %int3 = torch.constant.int 3\n" " %int2 = torch.constant.int 2\n" @@ -8620,7 +8709,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %true = torch.constant.bool true\n" " %0 = torch.prim.Uninitialized : !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" " %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" " torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" @@ -8650,7 +8739,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " func.func @\"__torch_mlir_dtype_fn.prims.sqrt\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %2 = torch.prim.If %1 -> (!torch.int) {\n" " torch.prim.If.yield %0#1 : !torch.int\n" " } else {\n" @@ -8800,6 +8889,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clamp_min.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%arg0: !torch.list>, %arg1: !torch.list) -> !torch.int {\n" +" %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list>, !torch.list) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.clamp\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " %int11 = torch.constant.int 11\n" @@ -8812,6 +8913,42 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clamp.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %int1 = torch.constant.int 1\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0 : (!torch.int) -> !torch.list>\n" +" %2 = torch.prim.ListConstruct %0#1 : (!torch.int) -> !torch.list\n" +" %3 = torch.aten.__isnot__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" %8 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.tuple\n" +" %9:2 = torch.prim.TupleUnpack %8 : !torch.tuple -> !torch.int, !torch.int\n" +" %10 = torch.aten.append.t %1, %9#0 : !torch.list>, !torch.int -> !torch.list>\n" +" %11 = torch.aten.append.t %2, %9#1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.__isnot__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" %8 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.tuple\n" +" %9:2 = torch.prim.TupleUnpack %8 : !torch.tuple -> !torch.int, !torch.int\n" +" %10 = torch.aten.append.t %1, %9#0 : !torch.list>, !torch.int -> !torch.list>\n" +" %11 = torch.aten.append.t %2, %9#1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.aten.len.t %1 : !torch.list> -> !torch.int\n" +" %6 = torch.aten.gt.int %5, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %2) : (!torch.list>, !torch.list) -> !torch.int\n" +" torch.prim.If.yield %8 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %7 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.clone\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -8841,7 +8978,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If.yield %2 : !torch.int\n" " } else {\n" " %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" " %4 = torch.prim.If %3 -> (!torch.int) {\n" " torch.prim.If.yield %int4 : !torch.int\n" " } else {\n" @@ -8910,13 +9047,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%arg0: !torch.list>, %arg1: !torch.list) -> !torch.int {\n" -" %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list>, !torch.list) -> !torch.int\n" -" return %0 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.gelu\"(%arg0: !torch.tuple, %arg1: !torch.str) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -8932,7 +9065,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_dtype_fn.aten.hardtanh_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number, %arg3: !torch.number) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %2 = torch.prim.If %1 -> (!torch.int) {\n" " torch.prim.If.yield %int6 : !torch.int\n" " } else {\n" @@ -8989,7 +9122,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" " torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" @@ -9004,7 +9137,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.lift_fresh_copy\"(%arg0: !torch.tuple) -> !torch.int {\n" @@ -9085,7 +9218,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " %5 = torch.aten.eq.int %4, %int11 : !torch.int, !torch.int -> !torch.bool\n" " %6 = torch.prim.If %5 -> (!torch.int) {\n" " torch.prim.If.yield %int4 : !torch.int\n" @@ -9230,7 +9363,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.threshold\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" @@ -9302,10 +9435,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %0#1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.prim.abs.Scalar\"(%arg0: !torch.number) -> !torch.int {\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" +" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" " return %0 : !torch.int\n" " }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.number) -> !torch.int {\n" +" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.number) -> !torch.int {\n" " %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.number -> !torch.tensor\n" " %1 = torch.prim.dtype %0 : !torch.tensor -> !torch.int\n" " return %1 : !torch.int\n" @@ -9317,7 +9450,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " %5 = torch.aten.eq.int %4, %int11 : !torch.int, !torch.int -> !torch.bool\n" " %6 = torch.prim.If %5 -> (!torch.int) {\n" " torch.prim.If.yield %int4 : !torch.int\n" @@ -9447,10 +9580,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_dtype_fn.aten.add\"(%arg0: !torch.number, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0 = torch.prim.ListConstruct %none, %none : (!torch.none, !torch.none) -> !torch.list>\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.int {\n" @@ -9464,7 +9597,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int5 = torch.constant.int 5\n" " %0 = torch.prim.Uninitialized : !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" " %3 = torch.prim.If %2 -> (!torch.int) {\n" " torch.prim.If.yield %1#1 : !torch.int\n" " } else {\n" @@ -9480,7 +9613,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %9 = torch.prim.If %8 -> (!torch.int) {\n" " torch.prim.If.yield %int10 : !torch.int\n" " } else {\n" -" %10 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %10 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" " %11 = torch.prim.If %10 -> (!torch.int) {\n" " torch.prim.If.yield %int9 : !torch.int\n" " } else {\n" @@ -9501,9 +9634,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.__and__.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" @@ -9511,7 +9644,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.__or__.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" @@ -9519,7 +9652,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.add.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number) -> !torch.int {\n" @@ -9527,7 +9660,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bitwise_and.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" @@ -9535,16 +9668,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bitwise_and.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bitwise_or.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" @@ -9552,7 +9685,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bitwise_xor.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" @@ -9560,7 +9693,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bitwise_right_shift.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" @@ -9568,14 +9701,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_left_shift.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_priority_of_dtype(%0#1) : (!torch.int) -> !torch.int\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_priority_of_dtype(%1#1) : (!torch.int) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_priority_of_dtype(%0#1) : (!torch.int) -> !torch.int\n" +" %3 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_priority_of_dtype(%1#1) : (!torch.int) -> !torch.int\n" " %4 = torch.aten.lt.int %2, %3 : !torch.int, !torch.int -> !torch.bool\n" " %5 = torch.prim.If %4 -> (!torch.int) {\n" " torch.prim.If.yield %0#1 : !torch.int\n" @@ -9584,7 +9725,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %5 : !torch.int\n" " }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_priority_of_dtype(%arg0: !torch.int) -> !torch.int {\n" +" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_priority_of_dtype(%arg0: !torch.int) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: Cannot determine priority of dtype\"\n" " %int15 = torch.constant.int 15\n" @@ -9684,7 +9825,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %str_1 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" " %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" " torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" @@ -9692,7 +9833,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" " torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" @@ -9702,7 +9843,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " %6 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %7 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" %8 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" " %9 = torch.aten.ne.int %8, %int11 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If %9 -> () {\n" " torch.prim.If.yield\n" @@ -9720,12 +9861,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%4) : (!torch.int) -> !torch.bool\n" " %6 = torch.prim.If %5 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%4) : (!torch.int) -> !torch.bool\n" " %9 = torch.prim.If %8 -> (!torch.bool) {\n" " %10 = torch.aten.ne.int %4, %int6 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If.yield %10 : !torch.bool\n" @@ -9764,12 +9905,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %5 = torch.prim.ListConstruct %4#0, %3#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %6 = torch.prim.ListConstruct %4#1, %3#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%5, %6) : (!torch.list>, !torch.list) -> !torch.int\n" -" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%7) : (!torch.int) -> !torch.bool\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%5, %6) : (!torch.list>, !torch.list) -> !torch.int\n" +" %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%7) : (!torch.int) -> !torch.bool\n" " %9 = torch.prim.If %8 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %12 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" +" %12 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" " %13 = torch.prim.If %12 -> (!torch.bool) {\n" " %14 = torch.aten.ne.int %7, %int6 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If.yield %14 : !torch.bool\n" @@ -9803,8 +9944,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_dtype_fn.aten.matmul\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_priority_of_dtype(%0#1) : (!torch.int) -> !torch.int\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_priority_of_dtype(%1#1) : (!torch.int) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_priority_of_dtype(%0#1) : (!torch.int) -> !torch.int\n" +" %3 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_priority_of_dtype(%1#1) : (!torch.int) -> !torch.int\n" " %4 = torch.aten.lt.int %2, %3 : !torch.int, !torch.int -> !torch.bool\n" " %5 = torch.prim.If %4 -> (!torch.int) {\n" " torch.prim.If.yield %0#1 : !torch.int\n" @@ -9818,7 +9959,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.minimum\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" @@ -9826,7 +9967,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" @@ -9854,7 +9995,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " } else {\n" " %7 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %8 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %9 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%7, %8) : (!torch.list>, !torch.list) -> !torch.int\n" +" %9 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%7, %8) : (!torch.list>, !torch.list) -> !torch.int\n" " torch.prim.If.yield %9 : !torch.int\n" " }\n" " return %6 : !torch.int\n" @@ -9866,8 +10007,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" " %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" " torch.prim.If %6 -> () {\n" " torch.prim.If.yield\n" @@ -9882,7 +10023,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mv\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" @@ -9890,7 +10031,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.sub.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number) -> !torch.int {\n" @@ -9898,7 +10039,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.threshold_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number) -> !torch.int {\n" @@ -9909,7 +10050,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %str_1 = torch.constant.str \"AssertionError: `grad_output` cannot be complex\"\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" " %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" " torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" @@ -9917,7 +10058,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" " torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" @@ -9927,7 +10068,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " %6 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %7 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" %8 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" " %9 = torch.prim.ListConstruct %int11 : (!torch.int) -> !torch.list\n" " %10 = torch.aten.__contains__.int_list %9, %8 : !torch.list, !torch.int -> !torch.bool\n" " %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n" @@ -9953,7 +10094,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" " %5 = torch.prim.If %4 -> (!torch.bool) {\n" " %12 = torch.aten.__isnot__ %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" @@ -9967,7 +10108,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" " %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" " %8 = torch.prim.If %7 -> (!torch.bool) {\n" " %12 = torch.aten.__isnot__ %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" @@ -9983,7 +10124,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %11 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list>, !torch.list) -> !torch.int\n" +" %11 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list>, !torch.list) -> !torch.int\n" " return %11 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten._convolution.deprecated\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.int {\n" @@ -10000,7 +10141,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" " %5 = torch.prim.If %4 -> (!torch.bool) {\n" " %12 = torch.aten.__isnot__ %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" @@ -10014,7 +10155,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" " %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" " %8 = torch.prim.If %7 -> (!torch.bool) {\n" " %12 = torch.aten.__isnot__ %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" @@ -10030,7 +10171,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %11 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list>, !torch.list) -> !torch.int\n" +" %11 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list>, !torch.list) -> !torch.int\n" " return %11 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.conv2d\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.int {\n" @@ -10058,7 +10199,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %false = torch.constant.bool false\n" " %int11 = torch.constant.int 11\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %2 = torch.prim.If %1 -> (!torch.bool) {\n" " %5 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If.yield %5 : !torch.bool\n" @@ -10093,7 +10234,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" " %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" " %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" -" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" " return %5 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.lerp.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple) -> !torch.int {\n" @@ -10102,7 +10243,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" " %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" " %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" -" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" " return %5 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number) -> !torch.int {\n" @@ -10135,7 +10276,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " %6 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" " %7 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" -" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" %8 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" " return %8 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.addcdiv\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number) -> !torch.int {\n" @@ -10145,8 +10286,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" " %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" " %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" -" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" -" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%5) : (!torch.int) -> !torch.bool\n" +" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%5) : (!torch.int) -> !torch.bool\n" " %7 = torch.prim.If %6 -> (!torch.int) {\n" " torch.prim.If.yield %int6 : !torch.int\n" " } else {\n" @@ -10158,27 +10299,27 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.sub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mul.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.div.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" @@ -10186,10 +10327,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" " %6 = torch.prim.If %5 -> (!torch.int) {\n" " torch.prim.If.yield %int6 : !torch.int\n" " } else {\n" @@ -10201,16 +10342,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" " torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" @@ -10219,27 +10360,27 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If.yield\n" " }\n" " %3 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %5 = torch.prim.ListConstruct %0#1, %4 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list>, !torch.list) -> !torch.int\n" " return %6 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.pow.Scalar\"(%arg0: !torch.number, %arg1: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %none, %0#0 : (!torch.none, !torch.int) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %2, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" @@ -10255,10 +10396,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If.yield\n" " }\n" " %2 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%3) : (!torch.int) -> !torch.bool\n" +" %3 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%3) : (!torch.int) -> !torch.bool\n" " torch.prim.If %4 -> () {\n" -" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n" " torch.prim.If %8 -> () {\n" " torch.prim.If.yield\n" @@ -10271,7 +10412,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If.yield\n" " }\n" " %5 = torch.prim.ListConstruct %0#1, %3 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %5) : (!torch.list>, !torch.list) -> !torch.int\n" " return %6 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.elu\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number, %arg3: !torch.number) -> !torch.int {\n" @@ -10293,7 +10434,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.Loop %int3, %true, init() {\n" " ^bb0(%arg4: !torch.int):\n" " %7 = torch.aten.__getitem__.t %3, %arg4 : !torch.list, !torch.int -> !torch.number\n" -" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%7) : (!torch.number) -> !torch.int\n" +" %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%7) : (!torch.number) -> !torch.int\n" " %9 = torch.aten.append.t %2, %8 : !torch.list, !torch.int -> !torch.list\n" " torch.prim.Loop.condition %true, iter()\n" " } : (!torch.int, !torch.bool) -> ()\n" @@ -10302,13 +10443,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.Loop %5, %true, init() {\n" " ^bb0(%arg4: !torch.int):\n" " %7 = torch.aten.__getitem__.t %2, %arg4 : !torch.list, !torch.int -> !torch.int\n" -" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" +" %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" " %9 = torch.aten.append.t %4, %8 : !torch.list, !torch.bool -> !torch.list\n" " torch.prim.Loop.condition %true, iter()\n" " } : (!torch.int, !torch.bool) -> ()\n" " %6 = torch.aten.any.bool %4 : !torch.list -> !torch.bool\n" " torch.prim.If %6 -> () {\n" -" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n" " torch.prim.If %8 -> () {\n" " torch.prim.If.yield\n" @@ -10326,9 +10467,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n" @@ -10360,7 +10501,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " %5 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %6 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %7 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%5, %6) : (!torch.list>, !torch.list) -> !torch.int\n" +" %7 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%5, %6) : (!torch.list>, !torch.list) -> !torch.int\n" " return %7 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.where.self\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple) -> !torch.int {\n" @@ -10368,18 +10509,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.where.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %int4 = torch.constant.int 4\n" " %false = torch.constant.bool false\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0) : (!torch.int) -> !torch.bool\n" +" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0) : (!torch.int) -> !torch.bool\n" " %2 = torch.prim.If %1 -> (!torch.bool) {\n" -" %4 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n" -" %5 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %4 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n" +" %5 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" " torch.prim.If.yield %5 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" @@ -10395,18 +10536,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarSelf\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %none, %0#0 : (!torch.none, !torch.int) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %2, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.nll_loss_forward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple {\n" @@ -10433,7 +10574,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" " torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" @@ -10473,7 +10614,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_dtype_fn.aten.native_batch_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple {\n" " %int6 = torch.constant.int 6\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %2 = torch.prim.If %1 -> (!torch.int) {\n" " torch.prim.If.yield %int6 : !torch.int\n" " } else {\n" @@ -10490,7 +10631,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.aten.__isnot__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" " %1 = torch.prim.If %0 -> (!torch.int) {\n" " %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" -" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" " %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" " torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" @@ -10500,8 +10641,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " torch.prim.If.yield %2 : !torch.int\n" " } else {\n" -" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" -" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %2 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" " %4 = torch.prim.If %3 -> (!torch.int) {\n" " torch.prim.If.yield %int6 : !torch.int\n" " } else {\n" @@ -10520,7 +10661,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" " %1 = torch.prim.If %0 -> (!torch.int) {\n" " %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" -" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" " %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" " torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" @@ -10530,13 +10671,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " torch.prim.If.yield %2 : !torch.int\n" " } else {\n" -" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" -" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %2 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" " %4 = torch.prim.If %3 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" -" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%6) : (!torch.int) -> !torch.bool\n" +" %6 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%6) : (!torch.int) -> !torch.bool\n" " torch.prim.If.yield %7 : !torch.bool\n" " }\n" " %5 = torch.prim.If %4 -> (!torch.int) {\n" @@ -10557,7 +10698,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" " %1 = torch.prim.If %0 -> (!torch.int) {\n" " %2 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" -" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" " %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" " torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" @@ -10567,20 +10708,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " torch.prim.If.yield %2 : !torch.int\n" " } else {\n" -" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" -" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %2 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" " %4 = torch.prim.If %3 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" -" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" " torch.prim.If.yield %8 : !torch.bool\n" " }\n" " %5 = torch.prim.If %4 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n" -" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n" +" %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" " torch.prim.If.yield %8 : !torch.bool\n" " }\n" " %6 = torch.prim.If %5 -> (!torch.int) {\n" @@ -10601,7 +10742,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If.yield %2 : !torch.int\n" " } else {\n" " %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" " %4 = torch.prim.If %3 -> (!torch.int) {\n" " torch.prim.If.yield %int4 : !torch.int\n" " } else {\n" @@ -10624,7 +10765,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If.yield %2 : !torch.int\n" " } else {\n" " %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" " %4 = torch.prim.If %3 -> (!torch.int) {\n" " torch.prim.If.yield %int4 : !torch.int\n" " } else {\n" @@ -10638,7 +10779,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %0 = call @\"__torch_mlir_dtype_fn.aten.sum\"(%arg0, %arg3) : (!torch.tuple, !torch.optional) -> !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0) : (!torch.int) -> !torch.bool\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0) : (!torch.int) -> !torch.bool\n" " %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" " torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" @@ -10652,6 +10793,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int4 = torch.constant.int 4\n" " return %int4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.argmin\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" return %int4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.any.dim\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " %int0 = torch.constant.int 0\n" @@ -10690,6 +10835,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple\n" " return %1 : !torch.tuple\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.min.dim\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.min\"(%arg0) : (!torch.tuple) -> !torch.int\n" +" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mean\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" " %false = torch.constant.bool false\n" " %none = torch.constant.none\n" @@ -10752,7 +10903,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" " torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" @@ -10763,7 +10914,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.aten.__isnot__ %arg4, %none : !torch.optional, !torch.none -> !torch.bool\n" " %4 = torch.prim.If %3 -> (!torch.int) {\n" " %5 = torch.prim.unchecked_cast %arg4 : !torch.optional -> !torch.int\n" -" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%5) : (!torch.int) -> !torch.bool\n" +" %6 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%5) : (!torch.int) -> !torch.bool\n" " %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" " torch.prim.If %7 -> () {\n" " torch.prim.If.yield\n" @@ -10771,9 +10922,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %9 = torch.prim.If %8 -> (!torch.int) {\n" -" %10 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%5) : (!torch.int) -> !torch.bool\n" +" %10 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%5) : (!torch.int) -> !torch.bool\n" " torch.prim.If %10 -> () {\n" " torch.prim.If.yield\n" " } else {\n" @@ -10784,7 +10935,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %12 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%11, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" " torch.prim.If.yield %12 : !torch.int\n" " } else {\n" -" %10 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%5) : (!torch.int) -> !torch.bool\n" +" %10 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%5) : (!torch.int) -> !torch.bool\n" " %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n" " torch.prim.If %11 -> () {\n" " torch.prim.If.yield\n" @@ -10905,8 +11056,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" " torch.prim.If.yield %2 : !torch.int\n" " } else {\n" -" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" -" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %2 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" " %4 = torch.prim.If %3 -> (!torch.int) {\n" " torch.prim.If.yield %int6 : !torch.int\n" " } else {\n" @@ -11059,7 +11210,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %5 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" " torch.prim.If.yield %5 : !torch.int\n" " }\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %3 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2) : (!torch.int) -> !torch.bool\n" " %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" " torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" @@ -11119,7 +11270,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If.yield %int4 : !torch.int\n" " } else {\n" " %2 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" -" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" " %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" " torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" @@ -11140,7 +11291,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If.yield %int6 : !torch.int\n" " } else {\n" " %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" -" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2) : (!torch.int) -> !torch.bool\n" " %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" " torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" @@ -11161,7 +11312,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If.yield %int6 : !torch.int\n" " } else {\n" " %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" -" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2) : (!torch.int) -> !torch.bool\n" " %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" " torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" @@ -11181,7 +11332,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" " torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" @@ -11214,7 +11365,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" " torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" @@ -11245,8 +11396,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" " %6 = torch.prim.If %5 -> (!torch.int) {\n" " torch.prim.If.yield %int6 : !torch.int\n" " } else {\n" @@ -11257,7 +11408,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_dtype_fn.aten.atan\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %2 = torch.prim.If %1 -> (!torch.int) {\n" " torch.prim.If.yield %int6 : !torch.int\n" " } else {\n" @@ -11270,7 +11421,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" " %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.cat\"(%arg0: !torch.list>, %arg1: !torch.int) -> !torch.int {\n" @@ -11297,7 +11448,34 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %9 = torch.aten.append.t %1, %7#1 : !torch.list, !torch.int -> !torch.list\n" " torch.prim.Loop.condition %true, iter()\n" " } : (!torch.int, !torch.bool) -> ()\n" -" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.einsum\"(%arg0: !torch.str, %arg1: !torch.list>, %arg2: !torch.optional>) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list>\n" +" %1 = torch.prim.ListConstruct : () -> !torch.list\n" +" %2 = torch.aten.len.t %arg1 : !torch.list> -> !torch.int\n" +" %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.len.t %arg1 : !torch.list> -> !torch.int\n" +" torch.prim.Loop %4, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %6 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list>, !torch.int -> !torch.tuple\n" +" %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple -> !torch.int, !torch.int\n" +" %8 = torch.aten.append.t %0, %7#0 : !torch.list>, !torch.int -> !torch.list>\n" +" %9 = torch.aten.append.t %1, %7#1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" " return %5 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten._shape_as_tensor\"(%arg0: !torch.tuple) -> !torch.int {\n" @@ -11314,7 +11492,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %str_0 = torch.constant.str \"AssertionError: \"\n" " %0 = torch.prim.Uninitialized : !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" " %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" " torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" @@ -11322,11 +11500,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n" " %5 = torch.prim.If %4 -> (!torch.int) {\n" " torch.prim.If.yield %int7 : !torch.int\n" " } else {\n" -" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %6 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" " %7 = torch.prim.If %6 -> (!torch.bool) {\n" " %9 = torch.aten.ne.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If.yield %9 : !torch.bool\n" @@ -11350,7 +11528,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %5 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.number) -> !torch.int {\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" +" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" " return %0 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.softmax.int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index adf9182df788..281a827858b7 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -187,6 +187,358 @@ static SmallVector computeDimsOrderForMoveDim(int64_t srcDimInt, return dimsOrder; } +static bool parseEquation(const std::string &equation, + SmallVector> &inputTokens, + SmallVector &resultTokens) { + SmallVector inputToken; + size_t index = 0; + enum EquationVariable { kIsInput, kIsResult }; + EquationVariable currentVariable = kIsInput; + while (index < equation.size()) { + if (std::isalpha(equation[index])) { + if (currentVariable == kIsInput) { + inputToken.push_back(equation[index]); + } else { + resultTokens.push_back(equation[index]); + } + } else if (equation[index] == ',') { + inputTokens.push_back(inputToken); + inputToken.clear(); + } else if ((index < (equation.size() - 1)) && + (equation.substr(index, 2).find("->") != std::string::npos)) { + inputTokens.push_back(inputToken); + inputToken.clear(); + currentVariable = kIsResult; + index++; + } else { + return false; + } + index++; + } + return true; +} + +// [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims] => +// [batchingDimsProd, lhsOtherDimsProd, lhsContractingDimsProd] +static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc, + Value input, int64_t batchDimsLength, + int64_t contractingDimsLength, + int64_t otherDimsLength, + int64_t reduceDimsLength, bool isLhs) { + auto inputType = input.getType().cast(); + auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength + + reduceDimsLength; + SmallVector inputShapeTensor; + for (auto i = 0; i < inputRank; ++i) { + inputShapeTensor.emplace_back(rewriter.create( + loc, input, + rewriter.create(loc, + rewriter.getI64IntegerAttr(i)))); + } + + SmallVector outShapeTensor; + Value constOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + auto dimOffset = 0; + + auto appendDims = [&](int64_t dimLength) { + Value prod = constOne; + for (auto i = 0; i < dimLength; ++i) { + prod = rewriter.create(loc, prod, + inputShapeTensor[i + dimOffset]); + } + outShapeTensor.emplace_back(prod); + dimOffset += dimLength; + }; + + appendDims(batchDimsLength); + if (!isLhs) + appendDims(contractingDimsLength); + appendDims(otherDimsLength + reduceDimsLength); + if (isLhs) + appendDims(contractingDimsLength); + + auto outShapeValue = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(input.getContext())), + outShapeTensor); + + auto outType = inputType.getWithSizesAndDtype(std::nullopt, + inputType.getOptionalDtype()); + return rewriter.create(loc, outType, input, + outShapeValue); +} + +// classify every dim token into different categories. Note that although we +// parse out reduce dims, we delay their execution until +// `performLastPermuteAndReduce`. +static void parseDimTokens( + SmallVector &lhsTokens, SmallVector &rhsTokens, + SmallVector &finalResultTokens, SmallVector &contractingDims, + SmallVector &lhsReduceDims, SmallVector &rhsReduceDims, + SmallVector &batchingDims, SmallVector &lhsOtherDims, + SmallVector &rhsOtherDims) { + llvm::SmallDenseSet lhsTokenSet(lhsTokens.begin(), lhsTokens.end()); + llvm::SmallDenseSet rhsTokenSet(rhsTokens.begin(), rhsTokens.end()); + llvm::SmallDenseSet finalResultTokenSet(finalResultTokens.begin(), + finalResultTokens.end()); + + for (size_t i = 0; i < lhsTokens.size(); ++i) { + bool rhsContains = rhsTokenSet.contains(lhsTokens[i]); + bool finalResultConatins = finalResultTokenSet.contains(lhsTokens[i]); + // batching dim + if (rhsContains && finalResultConatins) { + batchingDims.push_back(lhsTokens[i]); + // reduce dim of lhs + } else if (!rhsContains && !finalResultConatins) { + lhsReduceDims.push_back(lhsTokens[i]); + // other dim of lhs + } else if (finalResultConatins) { + lhsOtherDims.push_back(lhsTokens[i]); + // contracting dim of lhs + } else if (rhsContains) { + contractingDims.push_back(lhsTokens[i]); + } + } + + for (size_t i = 0; i < rhsTokens.size(); ++i) { + bool lhsContains = lhsTokenSet.contains(rhsTokens[i]); + bool finalResultConatins = finalResultTokenSet.contains(rhsTokens[i]); + // batching dim + if (lhsContains && finalResultConatins) { + // reduce dim of rhs + } else if (!lhsContains && !finalResultConatins) { + rhsReduceDims.push_back(rhsTokens[i]); + // other dim of rhs + } else if (finalResultConatins) { + rhsOtherDims.push_back(rhsTokens[i]); + // contracting dim of rhs + } else if (lhsContains) { + } + } +} + +static void generateIdealReusltDimTokens(SmallVector &batchingDims, + SmallVector &lhsOtherDims, + SmallVector &rhsOtherDims, + SmallVector &lhsReduceDims, + SmallVector &rhsReduceDims, + SmallVector &resultTokens) { + // generate ideal result dims, i.e., + // [*batchingDims, *lhsOtherDims, *lhsReduceDims, *rhsOtherDims, + // *rhsReduceDims] + resultTokens.insert(resultTokens.end(), batchingDims.begin(), + batchingDims.end()); + resultTokens.insert(resultTokens.end(), lhsOtherDims.begin(), + lhsOtherDims.end()); + resultTokens.insert(resultTokens.end(), lhsReduceDims.begin(), + lhsReduceDims.end()); + resultTokens.insert(resultTokens.end(), rhsOtherDims.begin(), + rhsOtherDims.end()); + resultTokens.insert(resultTokens.end(), rhsReduceDims.begin(), + rhsReduceDims.end()); +} + +static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc, + Value input, SmallVector &dimTokens, + SmallVector &batchingDims, + SmallVector &contractingDims, + SmallVector &otherDims, + SmallVector &reduceDims, bool isLhs) { + auto inputType = input.getType().cast(); + llvm::SmallDenseMap dimTokenMap; + for (size_t idx = 0; idx < dimTokens.size(); ++idx) { + dimTokenMap[dimTokens[idx]] = idx; + } + + SmallVector permuteVec; + auto appendDims = [&](SmallVector dimTokens) { + for (auto d : dimTokens) { + permuteVec.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(dimTokenMap[d]))); + } + }; + + appendDims(batchingDims); + if (!isLhs) + appendDims(contractingDims); + appendDims(otherDims); + appendDims(reduceDims); + if (isLhs) + appendDims(contractingDims); + + Value dstDims = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())), + permuteVec); + auto outType = inputType.getWithSizesAndDtype(std::nullopt, + inputType.getOptionalDtype()); + return rewriter.create(loc, outType, input, dstDims); +} + +static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, + Value lhs, SmallVector &lhsTokens, + Value rhs, SmallVector &rhsTokens, + Value &result, + SmallVector &resultTokens, + SmallVector &finalResultTokens) { + auto lhsType = lhs.getType().cast(); + auto rhsType = rhs.getType().cast(); + + Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype() + : rhsType.getOptionalDtype(); + + llvm::SmallDenseMap lhsDimShapeMap; + for (size_t idx = 0; idx < lhsTokens.size(); ++idx) { + char d = lhsTokens[idx]; + lhsDimShapeMap[d] = rewriter.create( + loc, lhs, + rewriter.create(loc, + rewriter.getI64IntegerAttr(idx))); + } + llvm::SmallDenseMap rhsDimShapeMap; + for (size_t idx = 0; idx < rhsTokens.size(); ++idx) { + char d = rhsTokens[idx]; + rhsDimShapeMap[d] = rewriter.create( + loc, rhs, + rewriter.create(loc, + rewriter.getI64IntegerAttr(idx))); + } + + // parse batch, contracting, other, reduce dims of lhs and rhs + SmallVector contractingDims; + SmallVector lhsReduceDims; + SmallVector rhsReduceDims; + SmallVector lhsOtherDims; + SmallVector rhsOtherDims; + SmallVector batchingDims; + parseDimTokens(lhsTokens, rhsTokens, finalResultTokens, contractingDims, + lhsReduceDims, rhsReduceDims, batchingDims, lhsOtherDims, + rhsOtherDims); + + llvm::SmallDenseMap outDimShapeMap; + auto generateOutDimShapeMap = [&](SmallVector &dims) { + for (auto d : dims) { + bool lhsContains = lhsDimShapeMap.count(d) > 0; + bool rhsContains = rhsDimShapeMap.count(d) > 0; + if (lhsContains && rhsContains) { + outDimShapeMap[d] = rewriter.create( + loc, lhsDimShapeMap[d], rhsDimShapeMap[d]); + } else if (lhsContains) { + outDimShapeMap[d] = lhsDimShapeMap[d]; + } else if (rhsContains) { + outDimShapeMap[d] = rhsDimShapeMap[d]; + } + } + }; + + generateOutDimShapeMap(contractingDims); + generateOutDimShapeMap(batchingDims); + generateOutDimShapeMap(lhsReduceDims); + generateOutDimShapeMap(rhsReduceDims); + generateOutDimShapeMap(lhsOtherDims); + generateOutDimShapeMap(rhsOtherDims); + + if (contractingDims.size() == 0 && lhsOtherDims.size() == 0 && + rhsOtherDims.size() == 0) { + return rewriter.notifyMatchFailure( + loc, "Hadamard product is currently not supported"); + } + + // shape: [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims] + lhs = permuteTensorForMatmul(rewriter, loc, lhs, lhsTokens, batchingDims, + contractingDims, lhsOtherDims, lhsReduceDims, + true); + // shape: [*batchingDims, *rhsContractingDims, *rhsOtherDims, *rhsReduceDims] + rhs = permuteTensorForMatmul(rewriter, loc, rhs, rhsTokens, batchingDims, + contractingDims, rhsOtherDims, rhsReduceDims, + false); + // shape: [batchingDimsProd, lhsOtherDimsProd, lhsContractingDimsProd] + lhs = collapseDimForMatmul(rewriter, loc, lhs, batchingDims.size(), + contractingDims.size(), lhsOtherDims.size(), + lhsReduceDims.size(), true); + // shape: [batchingDimsProd, rhsContractingDimsProd, rhsOtherDimsProd] + rhs = collapseDimForMatmul(rewriter, loc, rhs, batchingDims.size(), + contractingDims.size(), rhsOtherDims.size(), + rhsReduceDims.size(), false); + + // perform matmul + auto outType = lhsType.getWithSizesAndDtype(std::nullopt, outputDType); + result = rewriter.create(loc, outType, lhs, rhs); + + // generate ideal result dims. + generateIdealReusltDimTokens(batchingDims, lhsOtherDims, rhsOtherDims, + lhsReduceDims, rhsReduceDims, resultTokens); + + // reshape matmul result to ideal shape: + // [batchingDimsProd, lhsOtherDimsProd, rhsOtherDimsProd] => + // [*batchingDims, *lhsOtherDims, *lhsReduceDims, *rhsOtherDims, + // *rhsReduceDims] + SmallVector outShapeTensors; + for (char d : resultTokens) { + outShapeTensors.emplace_back(outDimShapeMap[d]); + } + + auto outResultShape = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(lhs.getContext())), + outShapeTensors); + result = rewriter.create( + loc, lhsType.getWithSizesAndDtype(std::nullopt, outputDType), result, + outResultShape); + return success(); +} + + +static Value performLastReduceAndPermute(PatternRewriter &rewriter, + Location loc, Type outType, + Value input, + SmallVector &inputTokens, + SmallVector &outTokens) { + auto inputType = input.getType().cast(); + + llvm::SmallDenseSet outTokenSet(outTokens.begin(), outTokens.end()); + SmallVector sumDims; + llvm::SmallDenseMap inputDimToIdx; + int64_t idx = 0; + for (size_t i = 0; i < inputTokens.size(); ++i) { + char d = inputTokens[i]; + if (!outTokenSet.contains(d)) { + sumDims.emplace_back(i); + } else { + inputDimToIdx[d] = idx++; + } + } + + if (sumDims.size() > 0) { + SmallVector sumDimsTensor; + for (auto d : sumDims) { + sumDimsTensor.emplace_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(d))); + } + auto sumDimsListValue = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())), + sumDimsTensor); + auto falseValue = rewriter.create( + loc, rewriter.getBoolAttr(false)); + auto noneValue = rewriter.create(loc); + input = rewriter.create( + loc, + inputType.getWithSizesAndDtype(std::nullopt, + inputType.getOptionalDtype()), + input, sumDimsListValue, falseValue, noneValue); + } + + SmallVector permuteDimsTensor; + for (auto d : outTokens) { + permuteDimsTensor.emplace_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(inputDimToIdx[d]))); + } + auto permuteDimsListValue = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(input.getContext())), + permuteDimsTensor); + auto out = rewriter.create(loc, outType, input, + permuteDimsListValue); + return out; +} + namespace { /// We decompose aten.amax into a set of aten.max.dim op(s) depending on the /// number of dimensions across which the max needs to be computed. @@ -246,6 +598,62 @@ class DecomposeAtenAmaxOp : public OpRewritePattern { }; } // end namespace +namespace { +class DecomposeAtenTriuOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenTriuOp op, + PatternRewriter &rewriter) const override { + MLIRContext *context = op.getContext(); + Location loc = op.getLoc(); + Value input = op.getSelf(); + auto inputType = input.getType().cast(); + if (!inputType.hasSizes() || !inputType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "should have shape and dtype"); + } + if (inputType.getSizes().size() < 2) { + return rewriter.notifyMatchFailure(op, "the rank of tensor should >= 2"); + } + + auto baseType = ValueTensorType::getWithLeastStaticInformation(context); + Value cstZero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value cstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value none = rewriter.create(loc); + + Value rowDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(-2)); + Value colDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(-1)); + Value rowSize = rewriter.create(loc, input, rowDim); + Value colSize = rewriter.create(loc, input, colDim); + + Value rowArange = rewriter.create( + loc, baseType, rowSize, /*dtype=*/none, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none); + Value colArange = rewriter.create( + loc, baseType, colSize, /*dtype=*/none, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none); + + Value unsqueezeRowArange = + rewriter.create(loc, baseType, rowArange, cstOne); + Value unsqueezeColArange = + rewriter.create(loc, baseType, colArange, cstZero); + + Value unsqueezeRowArangePlusDiagonal = rewriter.create( + loc, baseType, unsqueezeRowArange, op.getDiagonal(), cstOne); + + Value condTensor = rewriter.create( + loc, baseType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal); + + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), condTensor, input, cstZero); + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenSizeOp : public OpRewritePattern { public: @@ -572,6 +980,78 @@ class DecomposeAtenReshapeOp : public OpRewritePattern { }; } // namespace +namespace { +// Decompose AtenEinsumOp to AtenMatmulOp, and supports possible reduce +// operation and permute operation. Currently, this pass doesn't support +// Hadamard product. The basic idea is that: +// Step 1: split the string equation to input/result tokens and find +// batchingDims, contractingDims, otherDims and reduceDims. +// Step 2: permute and reshape input tensors suitable +// for matmul operations. +// Step 3: use AtenMatmulOp to get the result. +// Step 4: iteratively execute step 2 & 3 until we get the final result. +// Step 5: perform remaining permute and reduce operations. +// notice: support static shape only + +class DecomposeAtenEinsumOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenEinsumOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + std::string equation; + if (!matchPattern(op.getEquation(), m_TorchConstantStr(equation))) { + return rewriter.notifyMatchFailure(op, "Unsupported value of equation"); + } + SmallVector resultTokens; + SmallVector> inputTokens; + if (!parseEquation(equation, inputTokens, resultTokens)) { + return rewriter.notifyMatchFailure( + op, "Unexpected character in equations encountered"); + } + + SmallVector inputTensors; + if (!getListConstructElements(op.getTensors(), inputTensors)) { + return rewriter.notifyMatchFailure( + op, "input should comes from a PrimListConstructOp"); + } + + auto allTensorHasSizes = [](Value tensor) { + auto type = tensor.getType().dyn_cast(); + if (!type || !type.hasSizes()) + return false; + return true; + }; + + if (!llvm::all_of(inputTensors, allTensorHasSizes)) { + return rewriter.notifyMatchFailure(op, + "all input tensors should have sizes"); + } + + SmallVector lhsTokens = inputTokens[0]; + Value lhs = inputTensors[0]; + Value result; + + for (size_t i = 1; i < inputTensors.size(); ++i) { + auto rhs = inputTensors[i]; + auto rhsTokens = inputTokens[i]; + SmallVector outTokens; + if (failed(performMatmul(rewriter, loc, lhs, lhsTokens, rhs, rhsTokens, + result, outTokens, resultTokens))) { + return failure(); + } + lhs = result; + lhsTokens = outTokens; + } + + result = performLastReduceAndPermute(rewriter, loc, op.getType(), lhs, + lhsTokens, resultTokens); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + // Calculates the softmax function on the given `input` tensor. Softmax(x) = // exp(x)/sum(exp(x)). // To avoid overflow we use the following decomposition rule: @@ -784,12 +1264,13 @@ class DecomposeAten_LogSoftmaxBackwardDataOp }; } // namespace -// Decompose `AtenArgMaxOp` into `AtenMaxDimOp`. +// Decompose `AtenArgMaxOp` into `AtenMaxDimOp` as well as `AtenArgMinOp` into `AtenMinDimOp` namespace { -class DecomposeAtenArgMaxOp : public OpRewritePattern { +template +class DecomposeAtenArgMinMaxOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenArgmaxOp op, + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); @@ -814,7 +1295,7 @@ class DecomposeAtenArgMaxOp : public OpRewritePattern { .cast(); // If the dim type is `NoneType` i.e. reduce along all the dimensions. - // `AtenMaxDimOp` doesn't support dim as `NoneType` so first the input + // `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so first the input // tensor is flattened to 1d tensor and then the reduction happens on the // 0th dimension. if (dim.getType().isa()) { @@ -829,13 +1310,14 @@ class DecomposeAtenArgMaxOp : public OpRewritePattern { input = rewriter.create(loc, flattenType, input, dim, end); } - Value maxResult = - rewriter - .create(loc, valueTensorType, indicesTensorType, - input, dim, keepDim) - .getIndices(); - rewriter.replaceOp(op, maxResult); + Value resultArg = + rewriter + .create(loc, valueTensorType, indicesTensorType, + input, dim, keepDim) + .getIndices(); + + rewriter.replaceOp(op, resultArg); return success(); } }; @@ -1095,16 +1577,22 @@ class DecomposeAtenMvOp : public OpRewritePattern { }; } // namespace -// Decompose aten.pixel_shuffle into: aten.permute and aten.reshape operations. +// Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and +// prims.collapse operations. +// +// If input is a tensor of shape +// (*leading_dims, C*r*r, H, W), // -// If input is a tensor of shape (*leading_dims, C*r*r, H, W), where -// leading_dims is of size N, then +// where leading_dims is of size N, then // X = pixel_shuffle(input, upscale_factor) // // gets replaced with -// A = input.reshape(*leading_dims, C, r, r, H, W) -// B = A.permute(0, ..., N, N+3, N+1, N+4, N+2) -// X = B.reshape(*leading_dims, C, r*H, r*W) +// X = input.split_dim(...) # shape (*leading_dims, C, r*r, H, W) +// X = X.split_dim(...) # shape (*leading_dims, C, r, r, H, W) +// X = X.permute(0, ..., N, N+3, N+1, N+4, N+2) +// # shape (*leading_dims, C, H, r, W, r) +// X = X.collapse(...) # shape (*leading_dims, C, r, H, r*W) +// X = X.collapse(...) # shape (*leading_dims, C, r*H, r*W) // // 'r' above is referred to as the 'upscale factor' or just 'factor' below. namespace { @@ -1115,7 +1603,6 @@ class DecomposeAtenPixelShuffleOp LogicalResult matchAndRewrite(AtenPixelShuffleOp op, PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); Value inValue = op.getSelf(); auto inType = inValue.getType().cast(); @@ -1127,22 +1614,6 @@ class DecomposeAtenPixelShuffleOp auto inShape = maybeSizes.value(); auto inRank = inShape.size(); - // TODO support dynamic shapes, probably by lowering pixel_shuffle to linalg - // directly. Pixel shuffle does a reshape that is hard to recover - // through pure torch (view) ops, especially in dynamic cases. - // - // See: https://github.com/llvm/torch-mlir/issues/2559 - // - // For now, we just fail the decomposition here so that a sensible error is - // provided: - for (auto dimSize : inShape) { - if (dimSize == kUnknownSize) { - return rewriter.notifyMatchFailure( - op, "Currently we only decompose pixel_shuffle if the input tensor " - "is statically shaped"); - } - } - // The input tensor must have at least 3 dimensions: (1) the channel // dimension which gets smaller by 'factor*factor', (2) the H channel which // gets larger by 'factor' and (3) the W channel which get larger by @@ -1152,6 +1623,29 @@ class DecomposeAtenPixelShuffleOp return rewriter.notifyMatchFailure( op, "Expected input tensor to have rank greater than 2."); + const auto inOptionalDType = inType.getOptionalDtype(); + + auto getTypeFromShape = [inOptionalDType](auto &&vals) { + // Get a vector of integers from a vector of Values. + auto getIntShape = [](auto &&vals) { + SmallVector shape; + shape.reserve(vals.size()); + for (auto v : vals) { + int64_t cst_val; + if (matchPattern(v, m_TorchConstantInt(&cst_val))) { + shape.push_back(cst_val); + } else { + shape.push_back(kUnknownSize); + } + } + return shape; + }; + + const auto intShape = getIntShape(vals); + return ValueTensorType::get(vals[0].getContext(), + llvm::ArrayRef(intShape), inOptionalDType); + }; + auto nLeadingDims = inRank - 3; // Get the size of the dimension 'i'. Note the use of 'createOrFold' instead @@ -1169,106 +1663,94 @@ class DecomposeAtenPixelShuffleOp auto factor = op.getUpscaleFactor(); - Value factorSquared = rewriter.createOrFold(loc, factor, factor); + Value outC = rewriter.createOrFold(loc, inC, factorSquared); Value outH = rewriter.createOrFold(loc, inH, factor); Value outW = rewriter.createOrFold(loc, inW, factor); - // Shape of 'A' in the comment at the top - SmallVector prePermuteShape; - prePermuteShape.reserve(nLeadingDims + 5); - - // Shape of 'B' in the comment at the top. - SmallVector postPermuteShape; - postPermuteShape.reserve(nLeadingDims + 5); - - SmallVector outShape; - outShape.reserve(nLeadingDims + 3); - - SmallVector permutation; - permutation.reserve(nLeadingDims + 5); + SmallVector dimensionConstants; + dimensionConstants.reserve(inRank + 2); + for (unsigned i = 0; i < inRank + 2; ++i) { + dimensionConstants.push_back( + rewriter.create(loc, rewriter.getI64IntegerAttr(i))); + } + SmallVector leadingDims; + leadingDims.reserve(nLeadingDims); for (unsigned i = 0; i < nLeadingDims; ++i) { - auto dimensionAttr = rewriter.getI64IntegerAttr(i); - Value dimensionValue = rewriter.create(loc, dimensionAttr); - Value leadingDimSize = - rewriter.createOrFold(loc, inValue, dimensionValue); - prePermuteShape.push_back(leadingDimSize); - postPermuteShape.push_back(leadingDimSize); - outShape.push_back(leadingDimSize); - permutation.push_back(dimensionValue); - + Value leadingDimSize = rewriter.createOrFold( + loc, inValue, dimensionConstants[i]); + leadingDims.push_back(leadingDimSize); } - const auto inOptionalDType = inType.getOptionalDtype(); + SmallVector partiallyExpandedShape = leadingDims; + partiallyExpandedShape.append({outC, factorSquared, inH, inW}); - auto getTypeFromShape = [inOptionalDType](auto &&vals) { - // Get a vector of integers from a vector of Values. - auto getIntShape = [](auto &&vals) { - SmallVector shape; - shape.reserve(vals.size()); - for (auto v : vals) { - int64_t cst_val; - if (matchPattern(v, m_TorchConstantInt(&cst_val))) { - shape.push_back(cst_val); - } else { - shape.push_back(kUnknownSize); - } - } - return shape; - }; + SmallVector prePermuteShape = leadingDims; + prePermuteShape.append({outC, factor, factor, inH, inW}); - const auto intShape = getIntShape(vals); - return ValueTensorType::get(vals[0].getContext(), - llvm::ArrayRef(intShape), inOptionalDType); - }; - - prePermuteShape.insert(prePermuteShape.end(), - {outC, factor, factor, inH, inW}); + SmallVector postPermuteShape = leadingDims; + postPermuteShape.append({outC, inH, factor, inW, factor}); - postPermuteShape.insert(postPermuteShape.end(), - {outC, inH, factor, inW, factor}); + SmallVector partiallyCollapsedShape = leadingDims; + partiallyCollapsedShape.append({outC, inH, factor, outW}); - outShape.insert(outShape.end(), {outC, outH, outW}); + SmallVector outShape = leadingDims; + outShape.append({outC, outH, outW}); + SmallVector permutation{dimensionConstants.begin(), + dimensionConstants.begin() + nLeadingDims}; SmallVector permutationTail{0, 3, 1, 4, 2}; for (uint64_t d : permutationTail) { - permutation.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(nLeadingDims + d))); + permutation.push_back(dimensionConstants[nLeadingDims + d]); } - auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext())); - - Value shapeA = - rewriter.create(loc, listType, prePermuteShape); - - Value A = rewriter.create( - loc, getTypeFromShape(prePermuteShape), inValue, shapeA); - Value permuteDimsOrder = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), permutation); - Value B = rewriter.create( - loc, getTypeFromShape(postPermuteShape), A, permuteDimsOrder); + // Split input channel inC -> (inC, factorSquared) + auto partiallyExpanded = + rewriter + .create( + loc, getTypeFromShape(partiallyExpandedShape), inValue, + dimensionConstants[nLeadingDims], outC) + .getResult(); + + // Split new dimension factorSquared -> (factor, factor) + auto fullyExpanded = rewriter.create( + loc, getTypeFromShape(prePermuteShape), partiallyExpanded, + dimensionConstants[nLeadingDims + 1], factor); + + // Perform the permutation + auto permuted = + rewriter.create(loc, getTypeFromShape(postPermuteShape), + fullyExpanded, permuteDimsOrder); - Value outShapeList = - rewriter.create(loc, listType, outShape); + // Collapse final 2 dimension + auto partiallyCollapsed = rewriter.create( + loc, getTypeFromShape(partiallyCollapsedShape), permuted, + dimensionConstants[nLeadingDims + 3], + dimensionConstants[nLeadingDims + 4]); + + // Collapse back to original rank + rewriter.replaceOpWithNewOp( + op, op.getType(), partiallyCollapsed, + dimensionConstants[nLeadingDims + 1], + dimensionConstants[nLeadingDims + 2]); - rewriter.replaceOpWithNewOp(op, op.getType(), B, - outShapeList); return success(); } }; } // namespace // ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6) -static Value -getRelu6Results(PatternRewriter &rewriter, Location loc, Value input) { +static Value getRelu6Results(PatternRewriter &rewriter, Location loc, + Value input) { BaseTensorType inputType = input.getType().cast(); Value relu = rewriter.create(loc, inputType, input); @@ -1815,7 +2297,7 @@ class DecomposeAtenUnflattenIntOp auto inputTensorType = self.getType().cast(); if (!inputTensorType || !inputTensorType.hasSizes()) { return rewriter.notifyMatchFailure(op, - "Expected input type having sizes"); + "Expected input type having sizes"); } ArrayRef inputShape = inputTensorType.getSizes(); @@ -1851,7 +2333,7 @@ class DecomposeAtenUnflattenIntOp Value dimSize = rewriter.create(loc, self, /*dim=*/dimValue); if (i == dimInt) { - int64_t inferredSizeInt = inputShape[i]; + int64_t inferredSizeInt = inputShape[i]; int64_t inferredDim; for (unsigned j = 0; j < sizesInts.size(); ++j) { if (sizesInts[j] == -1) { @@ -1865,11 +2347,9 @@ class DecomposeAtenUnflattenIntOp } } if (inferred) { - Value inferredSize = - rewriter.create( + Value inferredSize = rewriter.create( loc, rewriter.getI64IntegerAttr(inferredSizeInt)); - newSizes.insert( - newSizes.begin() + inferredDim + i, inferredSize); + newSizes.insert(newSizes.begin() + inferredDim + i, inferredSize); } } else { newSizes.push_back(dimSize); @@ -4097,6 +4577,21 @@ class DecomposeAtenClampMinOp : public OpRewritePattern { }; } // namespace +namespace { +// Decompose `aten.clamp_min.Tensor` op into `aten.clamp.Tensor` op. +class DecomposeAtenClampMinTensorOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenClampMinTensorOp op, + PatternRewriter &rewriter) const override { + Value constantNone = rewriter.create(op.getLoc()); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), op.getMin(), /*max=*/constantNone); + return success(); + } +}; +} // namespace + namespace { // Decompose `aten.clampMax` op into `aten.clamp` op. class DecomposeAtenClampMaxOp : public OpRewritePattern { @@ -4112,7 +4607,7 @@ class DecomposeAtenClampMaxOp : public OpRewritePattern { } // namespace namespace { -class DecomposeAtenCosineSimilarityOp +class DecomposeAtenCosineSimilarityOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenCosineSimilarityOp op, @@ -4139,7 +4634,7 @@ class DecomposeAtenCosineSimilarityOp indexBroadcastShapeTorchList); // Compute the mul of A and B - Value dotProduct = + Value dotProduct = rewriter.create(loc, broadcastType, x1, x2); Value cstFalse = rewriter.create(loc, false); Value cstNone = rewriter.create(loc); @@ -4150,17 +4645,17 @@ class DecomposeAtenCosineSimilarityOp loc, op.getType(), /*self=*/dotProduct, /*dim=*/dimList, /*keepdim=*/cstFalse, /*dtype=*/cstNone); - + // Compute the norm of A and B - Value ord = rewriter.create(loc, - rewriter.getF64FloatAttr(2.0)); + Value ord = rewriter.create( + loc, rewriter.getF64FloatAttr(2.0)); Value normA = rewriter.create( loc, op.getType(), x1, ord, dimList, /*keepdim=*/cstFalse, /*dtype=*/cstNone); Value normB = rewriter.create( loc, op.getType(), x2, ord, dimList, /*keepdim=*/cstFalse, /*dtype=*/cstNone); - + // Compute the product of the norms Value normProduct = rewriter.create(loc, op.getType(), normA, normB); @@ -5852,7 +6347,8 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal>(patterns); + addPatternIfTargetOpIsIllegal>(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -5873,6 +6369,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -5904,6 +6401,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -5961,6 +6459,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); GreedyRewriteConfig config; config.useTopDownTraversal = true; diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 6ac2e5b2b4c8..b0cd84ff6bfd 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -385,6 +385,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -414,6 +415,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -458,6 +460,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -500,6 +503,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); for (auto &opName : backendLegalOpsSet) { target.addLegalOp( OperationName(kTorchOpPrefix + opName.first().str(), context)); diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 5620668a82c3..5bd254d72be1 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -88,6 +88,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { return torch_upstream::ScalarType::Long; if (type.isSignedInteger(32)) return torch_upstream::ScalarType::Int; + if (type.isSignedInteger(16)) + return torch_upstream::ScalarType::Short; if (type.isSignlessInteger(1)) return torch_upstream::ScalarType::Bool; if (type.isBF16()) @@ -122,17 +124,18 @@ Type Torch::getTypeForTorchType( FailureOr Torch::getTypeForScalarType(MLIRContext *context, - torch_upstream::ScalarType dtypeInt, - mlir::IntegerType::SignednessSemantics signedness) { + torch_upstream::ScalarType dtypeInt) { switch (dtypeInt) { case torch_upstream::ScalarType::Float: return Float32Type::get(context); case torch_upstream::ScalarType::Double: return Float64Type::get(context); case torch_upstream::ScalarType::Long: - return IntegerType::get(context, 64, signedness); + return IntegerType::get(context, 64, mlir::IntegerType::Signed); case torch_upstream::ScalarType::Int: - return IntegerType::get(context, 32, signedness); + return IntegerType::get(context, 32, mlir::IntegerType::Signed); + case torch_upstream::ScalarType::Short: + return IntegerType::get(context, 16, mlir::IntegerType::Signed); case torch_upstream::ScalarType::Bool: return IntegerType::get(context, 1); case torch_upstream::ScalarType::BFloat16: @@ -142,7 +145,7 @@ Torch::getTypeForScalarType(MLIRContext *context, case torch_upstream::ScalarType::Byte: return mlir::IntegerType::get(context, 8, mlir::IntegerType::Unsigned); case torch_upstream::ScalarType::Char: - return mlir::IntegerType::get(context, 8, signedness); + return mlir::IntegerType::get(context, 8, mlir::IntegerType::Signed); case torch_upstream::ScalarType::ComplexHalf: return mlir::ComplexType::get(Float16Type::get(context)); case torch_upstream::ScalarType::ComplexFloat: @@ -243,15 +246,16 @@ bool Torch::isViewLikeOp(Operation *op) { TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp, AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp, AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp, - AtenViewAsComplexOp, AtenViewAsRealOp, AtenPixelShuffleOp>(op); + PrimsSplitDimOp, AtenViewAsComplexOp, AtenViewAsRealOp, + AtenPixelShuffleOp>(op); } Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Location loc, float value, Type dtype) { // Creating constants satisfying backend contract. - if (dtype.isInteger(64) || dtype.isInteger(32) || dtype.isInteger(8) || - dtype.isInteger(1)) + if (dtype.isInteger(64) || dtype.isInteger(32) || dtype.isInteger(16) || + dtype.isInteger(8) || dtype.isInteger(1)) return rewriter.create( loc, rewriter.getI64IntegerAttr((int64_t)value)); if (dtype.isF64() || dtype.isF32() || dtype.isF16() || dtype.isBF16()) diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index 0be0ec8ba3ea..ace6c1a40e74 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -22,6 +22,7 @@ #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h" #include "torch-mlir/Conversion/Passes.h" +#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" @@ -47,8 +48,8 @@ void mlir::torch::registerOptionalInputDialects( void mlir::torch::registerAllPasses() { mlir::torch::registerTorchPasses(); mlir::torch::registerTorchConversionPasses(); - mlir::torch::registerConversionPasses(); + mlir::torch::onnx_c::registerTorchOnnxToTorchPasses(); mlir::torch::TMTensor::registerPasses(); #ifdef TORCH_MLIR_ENABLE_REFBACKEND diff --git a/projects/CMakeLists.txt b/projects/CMakeLists.txt new file mode 100644 index 000000000000..4b54be65a79d --- /dev/null +++ b/projects/CMakeLists.txt @@ -0,0 +1,53 @@ +include(AddMLIRPython) + +# Configure PyTorch if we have any features enabled which require it. +if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OR TORCH_MLIR_ENABLE_LTC) + message(STATUS "Enabling PyTorch C++ dep (features depend on it)") + include(TorchMLIRPyTorch) + + TorchMLIRProbeForPyTorchInstall() + if(TORCH_MLIR_USE_INSTALLED_PYTORCH) + TorchMLIRConfigurePyTorch() + else() + # Assume it is a sibling to the overall project. + set(Torch_DIR "${PROJECT_SOURCE_DIR}/../libtorch/share/cmake/Torch") + message(STATUS "Attempting to locate libtorch as a sibling to the project: ${Torch_DIR}") + if(NOT EXISTS "${Torch_DIR}") + message(FATAL_ERROR "Without TORCH_MLIR_USE_INSTALLED_PYTORCH, expected to find Torch configuration at ${Torch_DIR}, which does not exist") + endif() + endif() + + find_package(Torch 1.11 REQUIRED) + + set(TORCHGEN_DIR ${Torch_ROOT}/../../../torchgen) + + include_directories(BEFORE + ${TORCH_INCLUDE_DIRS} + ${Python3_INCLUDE_DIRS} + ) + link_directories("${TORCH_INSTALL_PREFIX}/lib") + message(STATUS "TORCH_CXXFLAGS is = ${TORCH_CXXFLAGS}") + if(${CMAKE_SYSTEM_NAME} STREQUAL "Linux" AND NOT TORCH_CXXFLAGS) + message(WARNING + "When building on Linux TORCH_CXXFLAGS are almost always required but were not detected. " + "It is very likely this this will produce a non-functional installation. " + "See notes in build_tools/cmake/TorchMLIRPyTorch.cmake.") + endif() + message(STATUS "TORCH_LIBRARIES = ${TORCH_LIBRARIES}") +endif() + +# Include jit_ir_common if the jit_ir importer or LTC is enabled, +# since they both require it. +if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OR TORCH_MLIR_ENABLE_LTC) + add_subdirectory(jit_ir_common) +endif() + +# Include LTC. +if(TORCH_MLIR_ENABLE_LTC) + add_subdirectory(ltc) +endif() + +# Include overall PT1 project. +if(TORCH_MLIR_ENABLE_PROJECT_PT1) + add_subdirectory(pt1) +endif() diff --git a/projects/jit_ir_common/CMakeLists.txt b/projects/jit_ir_common/CMakeLists.txt new file mode 100644 index 000000000000..f0a3ff596748 --- /dev/null +++ b/projects/jit_ir_common/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(csrc/jit_ir_importer) diff --git a/projects/jit_ir_common/csrc/jit_ir_importer/CMakeLists.txt b/projects/jit_ir_common/csrc/jit_ir_importer/CMakeLists.txt new file mode 100644 index 000000000000..b5f24fb80e8c --- /dev/null +++ b/projects/jit_ir_common/csrc/jit_ir_importer/CMakeLists.txt @@ -0,0 +1,27 @@ +# Static library with core functionality. +# We can't use a shared library here, due to issues with linking on macOS-arm64 (the library itself won't build) +# For details, see: https://github.com/llvm/torch-mlir/runs/7919012376 +add_library(TorchMLIRJITIRImporter STATIC + class_annotator.cpp + function_importer.cpp + node_importer.cpp + ivalue_importer.cpp + torch_to_mlir_utils.cpp + ) +message(STATUS "Linking TorchMLIRJITImporter with ${TORCH_LIBRARIES}") +target_link_libraries(TorchMLIRJITIRImporter + TorchMLIRAggregateCAPI + ${TORCH_LIBRARIES} + ) +# Includes are relative to the csrc dir (i.e. #include "jit_ir_importer/...") +target_include_directories(TorchMLIRJITIRImporter PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/.. +) +set_target_properties(TorchMLIRJITIRImporter PROPERTIES + LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs" + OUTPUT_NAME lib_jit_ir_importer + PREFIX "" + SUFFIX ".a" + CXX_VISIBILITY_PRESET "default" + COMPILE_FLAGS "${TORCH_CXXFLAGS}" + ) diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.cpp b/projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.cpp similarity index 100% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.cpp rename to projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.cpp diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.h b/projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.h similarity index 100% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.h rename to projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.h diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp b/projects/jit_ir_common/csrc/jit_ir_importer/function_importer.cpp similarity index 100% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp rename to projects/jit_ir_common/csrc/jit_ir_importer/function_importer.cpp diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.h b/projects/jit_ir_common/csrc/jit_ir_importer/function_importer.h similarity index 100% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.h rename to projects/jit_ir_common/csrc/jit_ir_importer/function_importer.h diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options.h b/projects/jit_ir_common/csrc/jit_ir_importer/import_options.h similarity index 100% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options.h rename to projects/jit_ir_common/csrc/jit_ir_importer/import_options.h diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.cpp b/projects/jit_ir_common/csrc/jit_ir_importer/ivalue_importer.cpp similarity index 98% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.cpp rename to projects/jit_ir_common/csrc/jit_ir_importer/ivalue_importer.cpp index 75013d5ee9a5..ef02096eb340 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.cpp +++ b/projects/jit_ir_common/csrc/jit_ir_importer/ivalue_importer.cpp @@ -190,7 +190,8 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) { torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)), mlirRegionCreate()); MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0); - mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr, nullptr)); + mlirRegionAppendOwnedBlock(nnModuleRegion, + mlirBlockCreate(0, nullptr, nullptr)); MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion); InserterGuard inserterGuard(importBlock, nnModule); @@ -491,8 +492,9 @@ void IValueImporter::importClassType(c10::ClassType *classType) { toMlirNamedAttribute( "name", mlirStringAttrGet( context, toMlirStringRef(classAttribute.getName()))), - toMlirNamedAttribute("type", mlirTypeAttrGet(getMlirTypeFromTorchType( - loc, classAttribute.getType(), importOptions))), + toMlirNamedAttribute( + "type", mlirTypeAttrGet(getMlirTypeFromTorchType( + loc, classAttribute.getType(), importOptions))), isPrivate); } diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.h b/projects/jit_ir_common/csrc/jit_ir_importer/ivalue_importer.h similarity index 100% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.h rename to projects/jit_ir_common/csrc/jit_ir_importer/ivalue_importer.h diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/mlir_utils.h b/projects/jit_ir_common/csrc/jit_ir_importer/mlir_utils.h similarity index 100% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/mlir_utils.h rename to projects/jit_ir_common/csrc/jit_ir_importer/mlir_utils.h diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp b/projects/jit_ir_common/csrc/jit_ir_importer/node_importer.cpp similarity index 93% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp rename to projects/jit_ir_common/csrc/jit_ir_importer/node_importer.cpp index 15cffedbe834..0bb4722fcf77 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp +++ b/projects/jit_ir_common/csrc/jit_ir_importer/node_importer.cpp @@ -41,10 +41,9 @@ class NodeImporter { const ImportOptions &importOptions = {}); private: - MlirBlock - createBlockFor(Block *jitBlock, - c10::optional> blockArgTypes, - const ImportOptions &importOptions = {}); + MlirBlock createBlockFor(Block *jitBlock, + c10::optional> blockArgTypes, + const ImportOptions &importOptions = {}); void mapValue(Value *jitValue, MlirValue value); void mapResults(Node *node, MlirOperation operation); MlirValue lookupMappedValue(Value *jitValue); @@ -269,9 +268,9 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, terminatorOperandTypes, /*userAllowsRefinement=*/false)); }; - mlirRegionAppendOwnedBlock( - mlirOperationGetRegion(operation, 0), - importBlock(node->blocks()[0], createTerminator, c10::nullopt, importOptions)); + mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 0), + importBlock(node->blocks()[0], createTerminator, + c10::nullopt, importOptions)); return; } @@ -290,12 +289,12 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, resultTypes, /*userAllowsRefinement=*/false)); }; - mlirRegionAppendOwnedBlock( - mlirOperationGetRegion(operation, 0), - importBlock(node->blocks()[0], createTerminator, c10::nullopt, importOptions)); - mlirRegionAppendOwnedBlock( - mlirOperationGetRegion(operation, 1), - importBlock(node->blocks()[1], createTerminator, c10::nullopt, importOptions)); + mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 0), + importBlock(node->blocks()[0], createTerminator, + c10::nullopt, importOptions)); + mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 1), + importBlock(node->blocks()[1], createTerminator, + c10::nullopt, importOptions)); return; } @@ -303,8 +302,8 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, auto classType = node->input(0)->type()->cast(); auto methodName = node->s(c10::attr::name); torch::jit::Function *function = classType->findMethod(methodName); - MlirType calleeType = - getFunctionTypeFromSchema(context, function->getSchema(), importOptions); + MlirType calleeType = getFunctionTypeFromSchema( + context, function->getSchema(), importOptions); std::vector expectedTypes; for (int i = 0, e = mlirFunctionTypeGetNumInputs(calleeType); i < e; ++i) { expectedTypes.push_back(mlirFunctionTypeGetInput(calleeType, i)); @@ -361,10 +360,10 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, } } -MlirBlock NodeImporter::importBlock( - Block *jitBlock, CreateTerminatorFn createTerminator, - c10::optional> blockArgTypes, - const ImportOptions &importOptions) { +MlirBlock +NodeImporter::importBlock(Block *jitBlock, CreateTerminatorFn createTerminator, + c10::optional> blockArgTypes, + const ImportOptions &importOptions) { MlirBlock block = createBlockFor(jitBlock, blockArgTypes, importOptions); for (Node *node : jitBlock->nodes()) { importNode(node, block, importOptions); @@ -434,5 +433,6 @@ torch_mlir::importBlock(MlirContext context, Block *jitBlock, c10::optional> blockArgTypes, const ImportOptions &importOptions) { NodeImporter importer(context); - return importer.importBlock(jitBlock, createTerminator, blockArgTypes, importOptions); + return importer.importBlock(jitBlock, createTerminator, blockArgTypes, + importOptions); } diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.h b/projects/jit_ir_common/csrc/jit_ir_importer/node_importer.h similarity index 85% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.h rename to projects/jit_ir_common/csrc/jit_ir_importer/node_importer.h index dd01444f415a..7fce8b988c45 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.h +++ b/projects/jit_ir_common/csrc/jit_ir_importer/node_importer.h @@ -36,11 +36,11 @@ using CreateTerminatorFn = /// are required to be for correctness. The code will internally attempt to /// adjust the types to the block argument types. /// TODO: Formalize what type conversions are allowed here. -MlirBlock importBlock( - MlirContext context, torch::jit::Block *jitBlock, - CreateTerminatorFn createTerminator, - c10::optional> blockArgTypes = c10::nullopt, - const ImportOptions &importOptions = {}); +MlirBlock +importBlock(MlirContext context, torch::jit::Block *jitBlock, + CreateTerminatorFn createTerminator, + c10::optional> blockArgTypes = c10::nullopt, + const ImportOptions &importOptions = {}); } // namespace torch_mlir diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp b/projects/jit_ir_common/csrc/jit_ir_importer/torch_to_mlir_utils.cpp similarity index 100% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp rename to projects/jit_ir_common/csrc/jit_ir_importer/torch_to_mlir_utils.cpp diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h b/projects/jit_ir_common/csrc/jit_ir_importer/torch_to_mlir_utils.h similarity index 100% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h rename to projects/jit_ir_common/csrc/jit_ir_importer/torch_to_mlir_utils.h diff --git a/projects/ltc/CMakeLists.txt b/projects/ltc/CMakeLists.txt new file mode 100644 index 000000000000..892faabd7eb8 --- /dev/null +++ b/projects/ltc/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(csrc/base_lazy_backend) diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt b/projects/ltc/csrc/base_lazy_backend/CMakeLists.txt similarity index 76% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt rename to projects/ltc/csrc/base_lazy_backend/CMakeLists.txt index 2087f99eb53f..eee3044f0fc9 100644 --- a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt +++ b/projects/ltc/csrc/base_lazy_backend/CMakeLists.txt @@ -2,30 +2,6 @@ # Setup PyTorch/LTC #------------------------------------------------------------------------------- -include(TorchMLIRPyTorch) - -TorchMLIRProbeForPyTorchInstall() -if(TORCH_MLIR_USE_INSTALLED_PYTORCH) - TorchMLIRConfigurePyTorch() -else() - # Assume it is a sibling to the overall project. - set(Torch_DIR "${PROJECT_SOURCE_DIR}/../libtorch/share/cmake/Torch") - message(STATUS "Attempting to locate libtorch as a sibling to the project: ${Torch_DIR}") -endif() - -find_package(Torch 1.11 REQUIRED) - -set(TORCHGEN_DIR ${Torch_ROOT}/../../../torchgen) - -include_directories(BEFORE - ${TORCH_INCLUDE_DIRS} - ${CMAKE_CURRENT_SOURCE_DIR} - ${CMAKE_CURRENT_BINARY_DIR} - ${Python3_INCLUDE_DIRS} - ${PROJECT_SOURCE_DIR}/projects/pt1/python -) -link_directories("${TORCH_INSTALL_PREFIX}/lib") - set(LTC_GENERATED generated/LazyNativeFunctions.cpp generated/RegisterLazy.cpp @@ -80,6 +56,12 @@ add_library(torch_mlir_ltc_backend SHARED utils/tensor_utils.cpp ) target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17) +# Includes are resolved relative to csrc (i.e. #include "base_lazy_backend/..."). +# Add both the source and generated include directories. +target_include_directories(torch_mlir_ltc_backend PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/.. + ${CMAKE_CURRENT_BINARY_DIR}/.. +) add_dependencies(torch_mlir_ltc_backend TorchMLIRJITIRImporter @@ -112,13 +94,13 @@ add_custom_command( add_custom_command( TARGET torch_mlir_ltc_backend POST_BUILD COMMAND cp - ${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/*.h + ${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/*.h ${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/) add_custom_command( TARGET torch_mlir_ltc_backend POST_BUILD COMMAND cp - ${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/generated/*.h + ${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/generated/*.h ${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/generated/) add_custom_command( @@ -129,7 +111,7 @@ add_custom_command( add_custom_command( TARGET torch_mlir_ltc_backend POST_BUILD COMMAND cp - ${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/*.h + ${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/ops/*.h ${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/ops/) add_custom_command( @@ -140,5 +122,5 @@ add_custom_command( add_custom_command( TARGET torch_mlir_ltc_backend POST_BUILD COMMAND cp - ${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/*.h + ${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/utils/*.h ${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/utils/) diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/README.md b/projects/ltc/csrc/base_lazy_backend/README.md similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/README.md rename to projects/ltc/csrc/base_lazy_backend/README.md diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp b/projects/ltc/csrc/base_lazy_backend/backend_impl.cpp similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp rename to projects/ltc/csrc/base_lazy_backend/backend_impl.cpp diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/backend_impl.h b/projects/ltc/csrc/base_lazy_backend/backend_impl.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/backend_impl.h rename to projects/ltc/csrc/base_lazy_backend/backend_impl.h diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/dynamic_ir.cpp b/projects/ltc/csrc/base_lazy_backend/dynamic_ir.cpp similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/dynamic_ir.cpp rename to projects/ltc/csrc/base_lazy_backend/dynamic_ir.cpp diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/dynamic_ir.h b/projects/ltc/csrc/base_lazy_backend/dynamic_ir.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/dynamic_ir.h rename to projects/ltc/csrc/base_lazy_backend/dynamic_ir.h diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ir_builder.h b/projects/ltc/csrc/base_lazy_backend/ir_builder.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ir_builder.h rename to projects/ltc/csrc/base_lazy_backend/ir_builder.h diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp b/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.cpp similarity index 99% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp rename to projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.cpp index 4823b4929ab1..7e6f40c5c2e9 100644 --- a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp +++ b/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.cpp @@ -21,8 +21,8 @@ #include "mlir-c/IR.h" #include "mlir-c/Pass.h" -#include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h" #include "backend_impl.h" +#include "jit_ir_importer/function_importer.h" #include "mlir_lowering_context.h" #include "mlir_node.h" #include "utils/debug.h" diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h b/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h rename to projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.h diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp b/projects/ltc/csrc/base_lazy_backend/mlir_native_functions.cpp similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp rename to projects/ltc/csrc/base_lazy_backend/mlir_native_functions.cpp diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp b/projects/ltc/csrc/base_lazy_backend/mlir_node.cpp similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp rename to projects/ltc/csrc/base_lazy_backend/mlir_node.cpp diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h b/projects/ltc/csrc/base_lazy_backend/mlir_node.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h rename to projects/ltc/csrc/base_lazy_backend/mlir_node.h diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp b/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.cpp similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp rename to projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.cpp diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.h b/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.h rename to projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.h diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/device_data.cpp b/projects/ltc/csrc/base_lazy_backend/ops/device_data.cpp similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/device_data.cpp rename to projects/ltc/csrc/base_lazy_backend/ops/device_data.cpp diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/device_data.h b/projects/ltc/csrc/base_lazy_backend/ops/device_data.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/device_data.h rename to projects/ltc/csrc/base_lazy_backend/ops/device_data.h diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/generic.cpp b/projects/ltc/csrc/base_lazy_backend/ops/generic.cpp similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/generic.cpp rename to projects/ltc/csrc/base_lazy_backend/ops/generic.cpp diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/generic.h b/projects/ltc/csrc/base_lazy_backend/ops/generic.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/generic.h rename to projects/ltc/csrc/base_lazy_backend/ops/generic.h diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/index.cpp b/projects/ltc/csrc/base_lazy_backend/ops/index.cpp similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/index.cpp rename to projects/ltc/csrc/base_lazy_backend/ops/index.cpp diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/index.h b/projects/ltc/csrc/base_lazy_backend/ops/index.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/index.h rename to projects/ltc/csrc/base_lazy_backend/ops/index.h diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/ivalue.cpp b/projects/ltc/csrc/base_lazy_backend/ops/ivalue.cpp similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/ivalue.cpp rename to projects/ltc/csrc/base_lazy_backend/ops/ivalue.cpp diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/ivalue.h b/projects/ltc/csrc/base_lazy_backend/ops/ivalue.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/ivalue.h rename to projects/ltc/csrc/base_lazy_backend/ops/ivalue.h diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/split.cpp b/projects/ltc/csrc/base_lazy_backend/ops/split.cpp similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/split.cpp rename to projects/ltc/csrc/base_lazy_backend/ops/split.cpp diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/split.h b/projects/ltc/csrc/base_lazy_backend/ops/split.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/split.h rename to projects/ltc/csrc/base_lazy_backend/ops/split.h diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/to_copy.h b/projects/ltc/csrc/base_lazy_backend/ops/to_copy.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/to_copy.h rename to projects/ltc/csrc/base_lazy_backend/ops/to_copy.h diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/unbind_int.cpp b/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.cpp similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/unbind_int.cpp rename to projects/ltc/csrc/base_lazy_backend/ops/unbind_int.cpp diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/unbind_int.h b/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/unbind_int.h rename to projects/ltc/csrc/base_lazy_backend/ops/unbind_int.h diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp rename to projects/ltc/csrc/base_lazy_backend/shape_inference.cpp diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/tensor.cpp b/projects/ltc/csrc/base_lazy_backend/tensor.cpp similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/tensor.cpp rename to projects/ltc/csrc/base_lazy_backend/tensor.cpp diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/tensor.h b/projects/ltc/csrc/base_lazy_backend/tensor.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/tensor.h rename to projects/ltc/csrc/base_lazy_backend/tensor.h diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/debug.h b/projects/ltc/csrc/base_lazy_backend/utils/debug.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/debug.h rename to projects/ltc/csrc/base_lazy_backend/utils/debug.h diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/exception.h b/projects/ltc/csrc/base_lazy_backend/utils/exception.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/exception.h rename to projects/ltc/csrc/base_lazy_backend/utils/exception.h diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.cpp b/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.cpp similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.cpp rename to projects/ltc/csrc/base_lazy_backend/utils/jit_utils.cpp diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.h b/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.h rename to projects/ltc/csrc/base_lazy_backend/utils/jit_utils.h diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/string_utils.h b/projects/ltc/csrc/base_lazy_backend/utils/string_utils.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/string_utils.h rename to projects/ltc/csrc/base_lazy_backend/utils/string_utils.h diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/sys_utils.h b/projects/ltc/csrc/base_lazy_backend/utils/sys_utils.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/sys_utils.h rename to projects/ltc/csrc/base_lazy_backend/utils/sys_utils.h diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.cpp b/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.cpp similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.cpp rename to projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.cpp diff --git a/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h b/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.h similarity index 100% rename from projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h rename to projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.h diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 85822043bf13..0454c47e9f37 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -29,17 +29,9 @@ # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8 "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "IscloseStaticModule_basic", - "IscloseStaticModuleTrue_basic", + "IscloseStaticModuleTrue_basic" } -if torch_version_for_comparison() >= version.parse("2.2.0.dev20230926"): - LINALG_XFAIL_SET |= { - "Conv2dWithPaddingDilationStrideStaticModule_grouped", - "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", - "Convolution2DGroupsStatic_basic", - "ConvolutionModule2DGroups_basic", - } - TORCHDYNAMO_XFAIL_SET = { #### General TorchDynamo/PyTorch errors @@ -89,6 +81,7 @@ #ERROR: value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=-1.336e-32, max=+0.9152, mean=+0.4837) is not close to golden value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=+0.02233, max=+0.9152, mean=+0.4777) "UpSampleNearest2dDynamicFactor_basic", "ReduceMaxAlongDimUnsignedInt_basic", + "ReduceMinAlongDimUnsignedInt_basic", #ERROR: value (-56) is not equal to golden value (200) "AtenIntTensorByteDtypeModule_basic", # ERROR: assert isinstance(e, FakeTensor) @@ -99,6 +92,8 @@ # ERROR: RuntimeError: Found a custom (non-ATen) operator that either mutates or its inputs: prims::squeeze.. Getting these operators to work with functionalization requires some extra work. For mutable ops you need to register a corresponding out-of-place variant of the op, and you also need to register a Functionalization kernel that performs some boilerplate, telling functionalization to map from the mutable op to the out-of-place op. See a more complete example of how to do this at https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa. "PrimsSqueezeModule_basic", "PrimsSqueezeEmptyDimensionsModule_basic", + "SplitDimStaticModule_basic", + "SplitDimDynamicModule_basic", # ERROR: RuntimeError: Found a custom (non-ATen) operator that either mutates or its inputs: prims::view_of.. Getting these operators to work with functionalization requires some extra work. For mutable ops you need to register a corresponding out-of-place variant of the op, and you also need to register a Functionalization kernel that performs some boilerplate, telling functionalization to map from the mutable op to the out-of-place op. See a more complete example of how to do this at https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa. "PrimsViewOfModule_basic", @@ -352,14 +347,6 @@ 'OneHotModule_basic', } -if torch_version_for_comparison() >= version.parse("2.2.0.dev20230926"): - TORCHDYNAMO_XFAIL_SET |= { - "Conv2dWithPaddingDilationStrideStaticModule_grouped", - "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", - "Convolution2DGroupsStatic_basic", - "ConvolutionModule2DGroups_basic", - } - TORCHDYNAMO_CRASHING_SET = { # No upstream decompositions. # %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor) @@ -615,6 +602,9 @@ "EmptyLikeModule_int", "ExpandAsIntModule_basic", "ExpandModule_basic", + "EinsumStaticModule_basic", + "EinsumStaticFourDimensionModule_basic", + "EinsumStaticContractRhsModule_basic", "Fill_TensorFloat64WithFloat32_basic", "Fill_TensorFloat64WithFloat64_basic", "Fill_TensorFloat64WithInt64_basic", @@ -1016,48 +1006,112 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { - "PixelShuffleModuleStaticRank3Int64_basic", - "PixelShuffleModuleStaticRank4Float32_basic", - "IscloseStaticModule_basic", - "IscloseStaticModuleTrue_basic", - "TileBigDimsSizeModule_basic", - "TileSmallDimsSizeModule_basic", - "IndexPutImpl2DNoneIndexStaticModule_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", + "AddCDiv_Module_basic", + "AddCDivModule_basic", + "AddCMul_Module_basic", + "AddCMulModule_basic", + "Add_Module_basic", "AliasModule_basic", - "MaxPool2dEmptyStrideStaticModule_basic", + "ArangeDtypeFloatModule_basic", + "ArangeDtypeIntModule_basic", + "ArangeFalsePinMemoryModule_basic", + "ArangeFloatModule_basic", + "ArangeIntModule_basic", + "ArangeNegativeStartFloatModule_basic", + "ArangeNegativeStartIntModule_basic", + "ArangeStartFloatModule_basic", + "ArangeStartIntModule_basic", + "ArangeStartNegativeStepFloatModule_basic", + "ArangeStartNegativeStepIntModule_basic", + "ArangeStartOutDtypeModule_basic", + "ArangeStartOutModule_basic", + "ArangeStartOutViewModule_basic", + "ArangeStartStepFloatModule_basic", + "ArangeStartStepIntModule_basic", + "ArangeZeroElementOutputModule_basic", + "ArgmaxModule_keepDim", + "ArgmaxModule_with_dim", + "AtenComplex64Module_basic", + "AtenEyeMModuleCPUDevice_basic", + "AtenEyeMModuleDefaultDtype_basic", + "AtenEyeMModuleFalsePinMemory_basic", + "AtenEyeMModuleFloat2D_basic", + "AtenEyeMModuleInt2D_basic", + "AtenEyeModuleCPUDevice_basic", + "AtenEyeModuleDefaultDtype_basic", + "AtenEyeModuleFalsePinMemory_basic", + "AtenEyeModuleFloat2D_basic", + "AtenEyeModuleInt2D_basic", + "AtenRoundIntModule_basic", + "AtenToDeviceModule_basic", + "AtenToDtypeModule_basic", + "BaddbmmBroadcast1DInputModule_basic", + "BaddbmmBroadcast2DInputModule_basic", + "BaddbmmDynamicModule_basic", + "BaddbmmStaticModule_basic", + "BaddbmmWithAlphaBetaModule_basic", + "BaddbmmWithAlphaModule_basic", + "BaddbmmWithBetaModule_basic", + "BatchNorm1DModule_basic", + "BatchNorm1DStaticShapeModule_basic", + "BatchNorm1DWith2DInputModule_basic", + "BatchNorm2DModule_basic", + "BatchNorm3DModule_basic", + "BmmFloatModule_basic", + "BoolTensorHandleSignless_basic", + "BoolTensorReturnFalseModule_basic", + "BoolTensorReturnMixedModule_basic", + "BoolTensorReturnTrueModule_basic", + "BroadcastDifferentRankSameFinalShapeModule_basic", + "BroadcastDifferentRankWithMinusOneModule_basic", + "BroadcastListConstructWithMinusOneModule_basic", + "BroadcastToDifferentRankNotOneStaticModule_basic", + "BroadcastToDifferentRankStaticModule_basic", + "BroadcastToSameRankStaticModule_basic", + "BroadcastZeroRankInputStaticModule_basic", + "BucketizeTensorStaticFloatModule_basic", + "BucketizeTensorStaticModule_basic", + "ChunkListUnpack_Module_basic", + "ChunkListUnpackUneven_Module_basic", "ConstantBoolParameterModule_basic", - "ElementwiseCloneContiguousModule_basic", - "ElementwiseCloneChannelsLastMemoryFormatModule_basic", - "ElementwiseCloneModule_basic", - "ElementwiseUnaryModule_basic", - "ElementwiseBinaryModule_basic", - "ElementwiseSigmoidModule_basic", - "ElementwiseExpModule_basic", - "ElementwiseReluModule_basic", - "ElementwiseLeakyReluModule_basic", - "ElementwiseEluModule_basic", - "ElementwiseEluNonDefaultModule_basic", - "ElementwiseFloorModule_basic", - "ElementwiseFloorIntModule_basic", - "ElementwiseLogModule_basic", - "ElementwiseBinaryStaticShapeModule_basic", - "ElementwiseMinimumModule_basic", - "ElementwiseMinimumIntModule_basic", - "ElementwiseMinOtherIntModule_basic", - "ElementwiseMinOtherModule_basic", - "ElementwiseMaximumModule_basic", - "ElementwiseMaximumIntModule_basic", - "ElementwiseSinModule_basic", - "ElementwiseCosModule_basic", + "ConstantPad2dStaticModule_basic", + "ConstantPadNdModule_basic", + "ConstantPadNdPartialStaticModule_basic", + "ConstantPadNdStaticModule_basic", + "ContiguousModule_basic", + "Conv1dNoPaddingGroupModule_basic", + "Conv1dNoPaddingModule_basic", + "Conv2dBiasNoPaddingModule_basic", + "Conv2dNoPaddingModule_basic", + "Conv2dWithPaddingDilationStrideModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", + "Conv2dWithPaddingDilationStrideStaticModule_grouped", + "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", + "Conv2dWithPaddingModule_basic", + "Convolution2DGroupsStatic_basic", + "Convolution2DStaticModule_basic", + "DetachModule_basic", + "DropoutEvalFloatModule_basic", + "DropoutEvalIntModule_basic", + "DropoutModule_basic", + "EinsumStaticContractRhsModule_basic", + "EinsumStaticFourDimensionModule_basic", + "EinsumStaticModule_basic", + "ElementwiseAbsModule_basic", + "ElementwiseAcosModule_basic", "ElementwiseAcosTensorFloatModule_basic", + "ElementwiseAddModule_basic", + "ElementwiseAddScalarFloatModule_basic", + "ElementwiseAddScalarInt64Module_basic", + "ElementwiseAddScalarIntModule_basic", + "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", "ElementwiseAsinTensorFloatModule_basic", "ElementwiseAtan2TensorFloatModule_basic", - "ElementwiseClampMaxModule_basic", - "ElementwiseClampMinModule_basic", - "ElementwiseClampModule_basic", - "ElementwiseClampIntModule_basic", - "ElementwiseMaxOtherIntModule_basic", - "ElementwiseMaxOtherModule_basic", + "ElementwiseAtenDivIntScalarModule_basic", "ElementwiseAtenLogicalAndOpModule_basic", "ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic", "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic", @@ -1074,429 +1128,357 @@ "ElementwiseAtenLogicalXorOpModule_basic", "ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic", "ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic", - "GluStaticModule_basic", - "ViewDoubleMergeStaticModule_basic", - "ViewCollapseOnesMiddleModule_basic", - "ViewFiveTestStaticModule_basic", - "ViewOffsetTestStaticModule_basic", - "ViewTwoFiveThreeStaticModule_basic", - "ViewTwoToThreeStaticModule_basic", - "ViewExpandOnesMiddleOppModule_basic", - "ViewOffsetBackwardTestStaticModule_basic", - "TanhBackward_basic", - "HardtanhBackward_basic", - "ElementwiseAddModule_basic", - "ReturnThreeTensorFloat32_basic", - "AddCMulModule_basic", - "AddCDivModule_basic", - "SqueezeModule_broadcast", - "BoolTensorReturnFalseModule_basic", - "BoolTensorReturnTrueModule_basic", - "BoolTensorReturnMixedModule_basic", - "BoolTensorHandleSignless_basic", - "ElementwiseRsqrtModule_basic", - "SelectIntNegativeDimAndIndexStaticModule_basic", - "SqueezeModule_static", - "SqueezeModule_noUnitDim", - "SqueezeModule_allUnitDim", - "TModuleRank1_basic", - "TModuleRank0_basic", - "ElementwiseToDtypeIdentityModule_basic", - "AtenToDeviceModule_basic", - "View1DFoldModule_basic", - "UnsafeView1DFoldModule_basic", - "UnflattenIntStaticModule_basic", - "UnflattenIntNegativeOneDimStaticModule_basic", - "UnflattenIntNegativeOneSizeStaticModule_basic", - "SqueezeDimModule_static", - "SqueezeDimModule_identity", - "SqueezeDimModule_unitDim", - "ReturnTwoTensorF32I64_basic", - "ElementwiseSignModule_basic", - "ElementwisePowModule_basic", - "ElementwisePowScalarModule_basic", - "ElementwisePowTensorBroadcastModule_basic", - "ElementwisePowTensorBroadcastStaticModule_basic", - "ElementwisePowTensorModule_basic", - # FIXME FXML-3631 - # "ElementwisePowTensorStaticModule_basic", - "AtenToDtypeModule_basic", - "BmmFloatModule_basic", - "MmDagModule_basic", - "Matmul4dStatic_basic", - "Matmul_dot", - "Matmul_3d", - "RsubFloatModule_basic", - "RsubFloatModule_noalpha_basic", - "RsubInt0d_NumToTensor_Module_basic", + "ElementwiseAtenWhereSelfModule_basic", + "ElementwiseBinaryModule_basic", + "ElementwiseBinaryStaticShapeModule_basic", "ElementwiseBitwiseAndModule_basic", "ElementwiseBitwiseAndStaticShapeModule_basic", "ElementwiseBitwiseNotInt32Module_basic", "ElementwiseBitwiseNotInt64Module_basic", - "ElementwiseOrTensorStaticShapeModule_basic", - "ElementwiseOrTensorModule_basic", "ElementwiseBitwiseOrModule_basic", "ElementwiseBitwiseOrStaticShapeModule_basic", "ElementwiseBitwiseXorModule_basic", "ElementwiseBitwiseXorStaticShapeModule_basic", + "ElementwiseCeilModule_basic", + "ElementwiseClampIntModule_basic", + "ElementwiseClampMaxModule_basic", + "ElementwiseClampMinModule_basic", + "ElementwiseClampModule_basic", + "ElementwiseCloneChannelsLastMemoryFormatModule_basic", + "ElementwiseCloneContiguousModule_basic", + "ElementwiseCloneModule_basic", + "ElementwiseCosModule_basic", + "ElementwiseDivScalarModule_basic", + "ElementwiseEluModule_basic", + "ElementwiseEluNonDefaultModule_basic", + "ElementwiseEqBoolScalarModule_basic", + "ElementwiseEqDiffWidthScalarModule_basic", + "ElementwiseEqFloatScalarModule_basic", + "ElementwiseEqFloatTensorModule_basic", + "ElementwiseEqIntScalarModule_basic", + "ElementwiseEqIntTensorModule_basic", + "ElementwiseErfModule_basic", + "ElementwiseExpModule_basic", + "ElementwiseFlattenBroadcastModule_basic", + "ElementwiseFloorIntModule_basic", + "ElementwiseFloorModule_basic", "ElementwiseGeFloatIntScalarModule_basic", "ElementwiseGeFloatScalarModule_basic", - "ElementwiseGeFloatTensorModule", - "ElementwiseGeIntTensorModule_basic", "ElementwiseGeFloatTensorModule_basic", "ElementwiseGeIntScalarModule_basic", "ElementwiseGeIntTensorModule_basic", - "ElementwiseGeIntScalarModule_basic", + "ElementwiseGeluModule_basic", "ElementwiseGeMixedIntScalarModule_basic", "ElementwiseGtFloatScalarModule_basic", - "ElementwiseGtIntScalarModule_basic", - "ElementwiseGtMixed2ScalarModule_basic", "ElementwiseGtFloatTensorModule_basic", + "ElementwiseGtIntScalarModule_basic", "ElementwiseGtIntTensorModule_basic", + "ElementwiseGtMixed2ScalarModule_basic", + "ElementwiseIsinfModule_basic", + "ElementwiseIsnanModule_basic", + "ElementwiseLeakyReluModule_basic", + "ElementwiseLeakyReluStaticModule_basic", "ElementwiseLeFloatIntScalarModule_basic", "ElementwiseLeFloatScalarModule_basic", + "ElementwiseLeFloatTensorModule_basic", + "ElementwiseLeFloatTensorNanModule_basic", "ElementwiseLeIntScalarModule_basic", + "ElementwiseLeIntTensorModule_basic", "ElementwiseLeMixedIntScalarModule_basic", - "ElementwiseLtFloatScalarModule_basic", - "ElementwiseLtIntScalarModule_basic", + "ElementwiseLog2Module_basic", + "ElementwiseLogModule_basic", "ElementwiseLtDiffWidthScalarModule_basic", + "ElementwiseLtFloatScalarModule_basic", "ElementwiseLtFloatTensorModule_basic", + "ElementwiseLtIntScalarModule_basic", "ElementwiseLtIntTensorModule_basic", - "ElementwiseEqFloatScalarModule_basic", - "ElementwiseEqIntScalarModule_basic", - "ElementwiseEqBoolScalarModule_basic", - "ElementwiseEqDiffWidthScalarModule_basic", - "ElementwiseEqFloatTensorModule_basic", - "ElementwiseEqIntTensorModule_basic", - "ElementwiseNeFloatScalarModule_basic", - "ElementwiseNeFloatTensorModule_basic", - "ElementwiseNeFloatTensorStaticModule_basic", - "ElementwiseNeIntTensorModule_basic", - "ElementwiseNeIntTensorStaticModule_basic", - "ElementwiseMulScalarModule_int", - "ElementwiseMulScalarModule_float", - "ElementwiseMulTensorIntModule_basic", - "ElementwiseDivScalarModule_basic", - "ElementwiseAtenDivIntScalarModule_basic", - "ElementwiseSubScalarFloatModule_basic", - "ElementwiseAddScalarFloatModule_basic", - "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", + "ElementwiseMaximumIntModule_basic", + "ElementwiseMaximumModule_basic", + "ElementwiseMaxOtherIntModule_basic", + "ElementwiseMaxOtherModule_basic", + "ElementwiseMinimumIntModule_basic", + "ElementwiseMinimumModule_basic", + "ElementwiseMinOtherIntModule_basic", + "ElementwiseMinOtherModule_basic", + "ElementwiseMulScalarModule_basic", "ElementwiseMulScalarModule_float", - "ElementwiseCeilModule_basic", + "ElementwiseMulScalarModule_int", + "ElementwiseMulTensorIntModule_basic", + "ElementwiseNeFloatScalarModule_basic", + "ElementwiseNeFloatTensorModule_basic", + "ElementwiseNeFloatTensorStaticModule_basic", + "ElementwiseNegModule_basic", + "ElementwiseNeIntScalarModule_basic", + "ElementwiseNeIntTensorModule_basic", + "ElementwiseNeIntTensorStaticModule_basic", + "ElementwiseOrTensorModule_basic", + "ElementwiseOrTensorStaticShapeModule_basic", + "ElementwisePowModule_basic", + "ElementwisePowScalarModule_basic", + "ElementwisePowTensorBroadcastModule_basic", + "ElementwisePowTensorBroadcastStaticModule_basic", + "ElementwisePowTensorModule_basic", "ElementwiseReciprocalModule_basic", - "ElementwiseIsnanModule_basic", - "RsubIntModule_basic", - "RsubIntModule_noalpha_basic", - "RsubIntStaticModule_noalpha_basic", - "ElementwiseIsinfModule_basic", - "TypePromotionAlphaWiderModule_basic", - "Conv1dNoPaddingModule_basic", - "Conv1dNoPaddingGroupModule_basic", - "Conv2dWithPaddingDilationStrideStaticModule_basic", - "Conv2dWithPaddingDilationStrideStaticModule_depthwise", - "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", - "Conv2dWithPaddingDilationStrideStaticModule_grouped", - "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", - "BatchNorm1DModule_basic", - "BatchNorm1DWith2DInputModule_basic", - "BatchNorm2DModule_basic", - "BatchNorm3DModule_basic", - "BatchNorm1DStaticShapeModule_basic", - "FlattenStaticModule_basic", - "UnflattenStaticModule_basic", - "FlattenRank0Module_basic", - "ElementwiseFlattenBroadcastModule_basic", - "SquareModule_basic", - "MaxPool2dStaticModule_basic", - "MaxPool2dStaticCeilModeTrueModule_basic", - "ResNet18StaticModule_basic", - "ReduceAmaxKeepDim_basic", - "NativeLayerNormModule4D_basic", - "LayerNormNormalizeOverAllDimsModule_basic", - "Permute0RankModule_basic", - "PermuteModule_basic", - "PermuteNegativeIndexModule_basic", - "ElementwiseLog2Module_basic", - "Threshold1dIntI32Module_basic", - "Threshold1dFloatModule_basic", - "Threshold2dFloatModule_basic", - "Threshold3dFloatModule_basic", + "ElementwiseRelu6Module_basic", + "ElementwiseReluModule_basic", + "ElementwiseRemainderScalarModule_Float_basic", + "ElementwiseRemainderScalarModule_Int_basic", + "ElementwiseRemainderScalarModule_Int_Float_basic", + "ElementwiseRsqrtModule_basic", + "ElementwiseSigmoidModule_basic", + "ElementwiseSignModule_basic", + "ElementwiseSinModule_basic", + "ElementwiseSqrtIntModule_basic", + "ElementwiseSqrtModule_basic", + "ElementwiseSubScalarFloatModule_basic", "ElementwiseSubScalarIntModule_basic", - "ElementwiseAddScalarIntModule_basic", - "ElementwiseMulScalarModule_basic", - "ZerosModuleDefaultDtype_basic", - "ZerosModuleInt2D_basic", - "ZerosModuleInt3D_basic", - "ZerosModuleFloat2D_basic", - "ZerosModuleFloat3D_basic", - "ZerosModuleFalsePinMemory_basic", - "OnesModuleDefaultDtype_basic", - "OnesModuleInt_basic", - "OnesModuleFloat_basic", - "OnesModuleFalsePinMemory_basic", - "OnesModuleCPUDevice_basic", - "NewZerosModuleDefaultDtype_basic", - "NewZerosModuleInt2D_basic", - "NewZerosModuleInt3D_basic", - "NewZerosModuleFloat2D_basic", - "NewZerosModuleFloat3D_basic", - "NewZerosModuleFalsePinMemory_basic", - "NewOnesModuleDefaultDtype_basic", - "NewOnesModuleInt2D_basic", - "NewOnesModuleInt3D_basic", - "NewOnesModuleFloat2D_basic", - "NewOnesModuleFloat3D_basic", - "NewOnesModuleFalsePinMemory_basic", - "SiluModule_basic", - "DropoutEvalIntModule_basic", - "DropoutEvalFloatModule_basic", - "ContiguousModule_basic", - "DropoutModule_basic", - "ViewExpandModule_basic", - "ViewExpandOnesModule_basic", - "ViewExpandOnesBeforeAndAfterModule_basic", - "ViewExpandOnesMiddleModule_basic", - "ViewExpandCollapseModule_basic", - "ViewExpandCollapseWithOnesModule_basic", - "ViewCollapseInferredDimModule_basic", - "ViewExpandInferredDimModule_basic", - "ViewNegativeStaticModule_basic", - "ViewNoChangeStaticModule_basic", - "UnsafeViewExpandModule_basic", - "ReshapeCollapseModule_basic", - "ElementwiseErfModule_basic", - "ReshapeAsModule_basic", - "ElementwiseGeluModule_basic", - "GeluBackwardModule_basic", - "ElementwiseNeIntScalarModule_basic", - "Convolution2DStaticModule_basic", - "Convolution2DGroupsStatic_basic", - "ElementwiseNegModule_basic", - "TestMultipleTensorReturn_basic", - "TypeAsSameModule_basic", - "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", - "BaddbmmDynamicModule_basic", - "BaddbmmStaticModule_basic", - "BaddbmmWithAlphaBetaModule_basic", - "BaddbmmWithAlphaModule_basic", - "BaddbmmWithBetaModule_basic", - "BaddbmmBroadcast1DInputModule_basic", - "BaddbmmBroadcast2DInputModule_basic", - "MatmulStaticBroadcast_basic", - "NumpyTRank0Module_basic", - "NumpyTRank1Module_basic", - "NumpyTRank2Module_basic", - "NumpyTRankNStaticModule_basic", - "NumpyTRankNDynamicModule_basic", - "EmbeddingModuleI32Static_basic", - "EmbeddingModule1DIndices_basic", - "TModuleRank2_basic", - "TransposeIntModule_basic", - "TransposeIntNegDimsModule_basic", - "ArgmaxModule_keepDim", - "ArgmaxModule_with_dim", - "_LogSoftmaxModuleStable_basic", - "ElementwiseAtenWhereSelfModule_basic", + "ElementwiseSubTensorInt8Module_basic", + "ElementwiseToDtypeIdentityModule_basic", + "ElementwiseUnaryModule_basic", "ElementwiseUnsqueezeBroadcastModule_basic", - "MaskedFillScalarIntValueModule_basic", - "MaskedFillScalarIntValueStaticModule_basic", - "MaskedFillTensorIntValueStaticModule_basic", - "ElementwiseAddScalarInt64Module_basic", - "TensorLiteralModule_basic", - "NewZerosStaticModuleLayoutStrided_basic", - "TensorOpaqueLiteralModule_basic", - "TypePromotionDifferentCategoryModule_basic", - "TypePromotionSameCategoryDifferentWidthModule_basic", - "TypePromotionSameCategoryZeroRankWider_basic", - "TypePromotionZeroRankHigherCategoryModule_basic", + "ElementwiseWhereScalarModule_basic", + "EmbeddingModule1DIndices_basic", + "EmbeddingModuleI32Static_basic", + "EmptyModule_contiguous", + "EmptyModule_defaultDtype", + "EmptyModule_falsePinMemory", + "EmptyModule_float", + "EmptyModule_uint8", + "EmptyStridedModule_basic", + "EyeStaticModule_basic", + "Fill_TensorFloat64WithFloat32Static_basic", + "Fill_TensorFloat64WithInt64Static_basic", + "FlattenRank0Module_basic", + "FlattenStaticModule_basic", + "FullLikeModuleFloat3DStatic_basic", + "FullLikeModuleInt2DStatic_basic", + "FullModuleDefaultDtype_basic", + "FullModuleFalsePinMemory_basic", + "FullModuleFloat2D_basic", + "FullModuleFloat3D_basic", + "FullModuleInt2D_basic", + "FullModuleInt3D_basic", "GatherStaticModule_basic", - "IndexPutImpl2DNoneIndexStaticModule_basic", - "IndexPutImpl2DNoneIndexBroadcastStaticModule_basic", + "GeluBackwardModule_basic", + "GluStaticModule_basic", + "HardsigmoidModule_basic", + "HardsigmoidRandomModule_basic", + "HardswishModule_basic", + "HardswishRandomModule_basic", + "HardtanhBackward_basic", + "HardTanhIntModule_basic", + "HardTanhModule_basic", "IndexPut1DFloatNonAccumulateModule_basic", "IndexPut1DIntNonAccumulateModule_basic", "IndexPutHackedTwin1DFloatNonAccumulateModule_basic", "IndexPutHackedTwin1DIntNonAccumulateModule_basic", "IndexPutImpl1DFloatNonAccumulateModule_basic", "IndexPutImpl1DIntNonAccumulateModule_basic", - "IndexTensorStaticModule_basic", - "IndexTensorMultiIndexStaticModule_basic", + "IndexPutImpl2DNoneIndexBroadcastStaticModule_basic", + "IndexPutImpl2DNoneIndexStaticModule_basic", "IndexTensorModule3dInputStatic_basic", - "ElementwiseWhereScalarModule_basic", - "FullLikeModuleFloat3DStatic_basic", - "FullModuleDefaultDtype_basic", - "FullModuleFloat3D_basic", - "FullModuleFalsePinMemory_basic", - "FullModuleInt2D_basic", - "NewFullModuleDefaultDtype_basic", - "NewFullModuleFalsePinMemory_basic", - "NewFullModuleFloat3DStatic_basic", - "NewFullModuleFloat3D_basic", + "IndexTensorMultiIndexStaticModule_basic", + "IndexTensorStaticModule_basic", + "IscloseStaticModule_basic", + "IscloseStaticModuleTrue_basic", + "LayerNormNormalizeOverAllDimsModule_basic", + "LeakyReluBackwardModule_basic", + "LeakyReluBackwardStaticModule_basic", + "LiftFreshCopyModule_basic", + "_LogSoftmaxModule_basic", + "_LogSoftmaxModuleStable_basic", "MaskedFillScalarDefaultModule_basic", "MaskedFillScalarFloatValueModule_basic", "MaskedFillScalarFloatValueStaticModule_basic", + "MaskedFillScalarIntValueModule_basic", + "MaskedFillScalarIntValueStaticModule_basic", + "MaskedFillTensorIntValueStaticModule_basic", + "Matmul_3d", + "Matmul4dStatic_basic", + "Matmul_dot", + "MatmulStaticBroadcast_basic", + "MaxPool2dEmptyStrideStaticModule_basic", + "MaxPool2dStaticCeilModeTrueModule_basic", + "MaxPool2dStaticModule_basic", + "MeanModule_basic", + "MmDagModule_basic", + "MoveDimIntModule_basic", + "MoveDimIntNegativeIndexModule_basic", + "MseLossNoReductionModule_basic", + "NativeLayerNormModule4D_basic", + "NewEmptyModuleBool_basic", + "NewEmptyModuleDefaultDtype_basic", + "NewEmptyModuleFalsePinMemory_basic", + "NewEmptyModuleFloat2D_basic", + "NewEmptyModuleFloat3D_basic", + "NewEmptyModuleLayoutIntDtype_basic", + "NewEmptyModuleNonDefaultFloatDtype_basic", + "NewEmptyModuleNonDefaultIntDtype_basic", + "NewEmptyStridedModuleDefaultDtype_basic", + "NewFullModuleDefaultDtype_basic", + "NewFullModuleFalsePinMemory_basic", + "NewFullModuleFloat2D_basic", + "NewFullModuleFloat3D_basic", + "NewFullModuleFloat3DStatic_basic", + "NewFullModuleInt2D_basic", + "NewFullModuleInt2DStatic_basic", + "NewFullModuleInt3D_basic", + "NewOnesModuleDefaultDtype_basic", + "NewOnesModuleFalsePinMemory_basic", + "NewOnesModuleFloat2D_basic", + "NewOnesModuleFloat3D_basic", + "NewOnesModuleInt2D_basic", + "NewOnesModuleInt3D_basic", + "NewZerosModuleDefaultDtype_basic", + "NewZerosModuleFalsePinMemory_basic", + "NewZerosModuleFloat2D_basic", + "NewZerosModuleFloat3D_basic", + "NewZerosModuleInt2D_basic", + "NewZerosModuleInt3D_basic", + "NewZerosStaticModuleLayoutStrided_basic", + "NumpyTRank0Module_basic", + "NumpyTRank1Module_basic", + "NumpyTRank2Module_basic", + "NumpyTRankNDynamicModule_basic", + "NumpyTRankNStaticModule_basic", "NumToTensorFloatModule_basic", - "LiftFreshCopyModule_basic", - "PrimsSumFloatModule_basic", + "NumToTensorIntModule_basic", + "OnesModuleCPUDevice_basic", + "OnesModuleDefaultDtype_basic", + "OnesModuleFalsePinMemory_basic", + "OnesModuleFloat_basic", + "OnesModuleInt_basic", + "PadModule_basic", + "PadWithNoneValModule_basic", + "Permute0RankModule_basic", + "PermuteModule_basic", + "PermuteNegativeIndexModule_basic", + "PrimsSqueezeEmptyDimensionsModule_basic", "PrimsSqueezeModule_basic", - "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", + "PrimsSumFloatModule_basic", + "PrimsViewOfModule_basic", + "PrimsViewOfZeroRankModule_basic", + "ReduceAmaxKeepDim_basic", + "ReduceSumDimIntListDtypeFloatModule_basic", + "ReduceSumDimIntListDtypeIntModule_basic", + "ReduceSumDimIntListElementTypeBoolModule_basic", "ReduceSumDimIntListFloatModule_basic", "ReduceSumDimIntListIntModule_basic", "ReduceSumDimIntListKeepDimFloatModule_basic", "ReduceSumDimIntListKeepDimIntModule_basic", + "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", + "ReduceSumDtypeFloatModule_basic", + "ReduceSumDtypeIntModule_basic", + "ReduceSumElementTypeBoolModule_basic", "ReduceSumFloatModule_basic", "ReduceSumSignedIntModule_basic", "ReduceSumUnsignedIntModule_basic", - "ReduceSumElementTypeBoolModule_basic", - "ReduceSumDimIntListDtypeFloatModule_basic", - "ReduceSumDimIntListDtypeIntModule_basic", - "ReduceSumDimIntListElementTypeBoolModule_basic", - "ReduceSumDtypeFloatModule_basic", - "ReduceSumDtypeIntModule_basic", - "BroadcastToDifferentRankStaticModule_basic", - "BroadcastToDifferentRankNotOneStaticModule_basic", - "BroadcastToSameRankStaticModule_basic", - "BroadcastZeroRankInputStaticModule_basic", - "BroadcastListConstructWithMinusOneModule_basic", - "BroadcastDifferentRankWithMinusOneModule_basic", - "BroadcastDifferentRankSameFinalShapeModule_basic", - "SliceStaticModule_basic", - "SliceSizeTwoStepDivisibleStaticModule_basic", - "SliceOutOfLowerBoundStartIndexStaticModule_basic", - "ArangeStartStepIntModule_basic", - "ArangeDtypeFloatModule_basic", - "ArangeIntModule_basic", - "ArangeNegativeStartIntModule_basic", - "ArangeStartIntModule_basic", - "ArangeStartNegativeStepIntModule_basic", - "ArangeZeroElementOutputModule_basic", - "ArangeDtypeIntModule_basic", - "ArangeFalsePinMemoryModule_basic", - "ArangeFloatModule_basic", - "ArangeNegativeStartFloatModule_basic", - "ArangeStartFloatModule_basic", - "ArangeStartNegativeStepFloatModule_basic", - "ArangeStartStepFloatModule_basic", - "NumToTensorIntModule_basic", - "ToDtypeBoolLayoutNoneStaticModule_basic", - "ToCopyBoolDTypeStaticModule_basic", - "HardTanhIntModule_basic", - "AtenRoundIntModule_basic", - "MseLossNoReductionModule_basic", - "AddCMul_Module_basic", - "AddCDiv_Module_basic", - "TestF16Return_basic", - "ElementwiseLeakyReluModule_basic", - "ElementwiseLeakyReluStaticModule_basic", - "LeakyReluBackwardModule_basic", - "LeakyReluBackwardStaticModule_basic", - "ElementwiseRelu6Module_basic", - "HardTanhModule_basic", - "HardsigmoidModule_basic", - "HardsigmoidRandomModule_basic", - "HardswishModule_basic", - "HardswishRandomModule_basic", - "BucketizeTensorStaticFloatModule_basic", - "BucketizeTensorStaticModule_basic", - "ElementwiseLeFloatTensorModule_basic", - "ElementwiseLeIntTensorModule_basic", - "FullLikeModuleInt2DStatic_basic", - "FullModuleInt3D_basic", - "FullModuleFloat2D_basic", - "ElementwiseAbsModule_basic", + "RepeatInterleaveFillModule_basic", + "RepeatInterleaveStaticModule_basic", "RepeatModule_basic", - "TensorsSplitTensorModule_basic", - "TensorsSplitTensorNegativeDimModule_basic", - "TensorsSplitTensorLastSmallerModule_basic", - "ConstantPad2dStaticModule_basic", - "ConstantPadNdModule_basic", - "ConstantPadNdPartialStaticModule_basic", - "ConstantPadNdStaticModule_basic", - "PadModule_basic", - "PadWithNoneValModule_basic", - "ElementwiseRemainderScalarModule_Float_basic", - "ElementwiseRemainderScalarModule_Int_Float_basic", - "ElementwiseRemainderScalarModule_Int_basic", - "PrimsSqueezeModule_basic", - "PrimsSqueezeEmptyDimensionsModule_basic", - "MoveDimIntModule_basic", - "MoveDimIntNegativeIndexModule_basic", - "ElementwiseRemainderScalarModule_Int_basic", - "MoveDimIntModule_basic", - "PrimsViewOfModule_basic", - "PrimsViewOfZeroRankModule_basic", - "DetachModule_basic", + "ReshapeAsModule_basic", + "ReshapeCollapseModule_basic", + "ResNet18StaticModule_basic", + "ReturnThreeTensorFloat32_basic", + "ReturnTwoTensorF32I64_basic", + "RsubFloatModule_basic", + "RsubFloatModule_noalpha_basic", + "RsubInt0d_NumToTensor_Module_basic", + "RsubIntModule_basic", + "RsubIntModule_noalpha_basic", + "RsubIntStaticModule_noalpha_basic", "ScalarTensorDefaultDtypeModule_basic", "ScalarTensorFloat32Module_basic", "ScalarTensorInt32Module_basic", "ScalarTensorInt64Module_basic", - "UnbindIntListUnpack_Module_basic", - "UnbindIntGetItem_Module_basic", - "TensorsConcatStaticModule_basic", - "TensorsConcatNegativeDimStaticModule_basic", - "TensorsConcatPromoteDTypeStaticModule_basic", - "AtenComplex64Module_basic", - "ElementwiseSqrtIntModule_basic", - "ElementwiseSqrtModule_basic", - "EmptyModule_defaultDtype", - "EmptyModule_int", - "EmptyModule_float", - "EmptyModule_contiguous", - "EmptyModule_falsePinMemory", - "NewEmptyModuleBool_basic", - "NewEmptyModuleDefaultDtype_basic", - "NewEmptyModuleLayoutIntDtype_basic", - "NewEmptyModuleFalsePinMemory_basic", - "NewEmptyModuleFloat2D_basic", - "NewEmptyModuleFloat3D_basic", - "NewEmptyModuleInt2D_basic", - "NewEmptyModuleInt3D_basic", - "NewEmptyModuleNonDefaultFloatDtype_basic", - "NewEmptyModuleNonDefaultIntDtype_basic", - "EmptyStridedModule_basic", - "NewEmptyStridedModuleDefaultDtype_basic", - "NewFullModuleInt2D_basic", - "NewFullModuleInt3D_basic", - "Fill_TensorFloat64WithInt64Static_basic", - "Fill_TensorFloat64WithFloat32Static_basic", + "SelectIntNegativeDimAndIndexStaticModule_basic", + "SiluModule_basic", + "SliceOutOfLowerBoundStartIndexStaticModule_basic", + "SliceSizeTwoStepDivisibleStaticModule_basic", + "SliceStaticModule_basic", + "SoftmaxIntModule_basic", + "SoftmaxIntNegDimModule_basic", + "_SoftmaxModule_basic", "SplitTensorGetItem_Module_basic", + "SplitTensorLastSmallerModule_basic", "SplitTensorListUnpackModule_basic", "SplitTensorNegativeDimModule_basic", - "SplitTensorLastSmallerModule_basic", "SplitWithSizesListUnpackModule_basic", - "ChunkListUnpack_Module_basic", - "ChunkListUnpackUneven_Module_basic", - "RepeatInterleaveStaticModule_basic", - "RepeatInterleaveFillModule_basic", + "SquareModule_basic", + "SqueezeDimModule_identity", + "SqueezeDimModule_static", + "SqueezeDimModule_unitDim", + "SqueezeModule_allUnitDim", + "SqueezeModule_broadcast", + "SqueezeModule_noUnitDim", + "SqueezeModule_static", + "TanhBackward_basic", + "TensorLiteralModule_basic", + "TensorOpaqueLiteralModule_basic", + "TensorsConcatNegativeDimStaticModule_basic", + "TensorsConcatPromoteDTypeStaticModule_basic", + "TensorsConcatStaticModule_basic", + "TensorsSplitTensorLastSmallerModule_basic", + "TensorsSplitTensorModule_basic", + "TensorsSplitTensorNegativeDimModule_basic", + "TestF16Return_basic", + "TestMultipleTensorReturn_basic", + "Threshold1dFloatModule_basic", + "Threshold1dIntI32Module_basic", + "Threshold2dFloatModule_basic", + "Threshold3dFloatModule_basic", + "TileBigDimsSizeModule_basic", + "TileSmallDimsSizeModule_basic", + "TModuleRank0_basic", + "TModuleRank1_basic", + "TModuleRank2_basic", + "ToCopyBoolDTypeStaticModule_basic", + "ToDtypeBoolLayoutNoneStaticModule_basic", + "TransposeIntModule_basic", + "TransposeIntNegDimsModule_basic", + "TriuBroadcastModule_basic", + "TriuModule_basic", "TupleModule_basic", + "TypeAsSameModule_basic", + "TypePromotionAlphaWiderModule_basic", + "TypePromotionDifferentCategoryModule_basic", + "TypePromotionSameCategoryDifferentWidthModule_basic", + "TypePromotionSameCategoryZeroRankWider_basic", + "TypePromotionZeroRankHigherCategoryModule_basic", + "UnbindIntGetItem_Module_basic", + "UnbindIntListUnpack_Module_basic", + "UnflattenIntNegativeOneDimStaticModule_basic", + "UnflattenIntNegativeOneSizeStaticModule_basic", + "UnflattenIntStaticModule_basic", + "UnflattenStaticModule_basic", "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", - "NumpyTRank0Module_basic", - "Permute0RankModule_basic", - "Add_Module_basic", - "SoftmaxIntModule_basic", - "SoftmaxIntNegDimModule_basic", - "_LogSoftmaxModule_basic", - "_SoftmaxModule_basic", - "ElementwiseSubTensorInt8Module_basic", - "AtenEyeMModuleInt2D_basic", - "AtenEyeMModuleCPUDevice_basic", - "AtenEyeMModuleDefaultDtype_basic", - "AtenEyeMModuleFalsePinMemory_basic", - "AtenEyeMModuleFloat2D_basic", - "EyeStaticModule_basic", - "AtenEyeModuleInt2D_basic", - "AtenEyeModuleCPUDevice_basic", - "AtenEyeModuleDefaultDtype_basic", - "AtenEyeModuleFalsePinMemory_basic", - "AtenEyeModuleFloat2D_basic", - "MeanModule_basic", - "ArangeStartOutModule_basic", - "ArangeStartOutDtypeModule_basic", - "ArangeStartOutViewModule_basic", - "Conv2dBiasNoPaddingModule_basic", - "Conv2dNoPaddingModule_basic", - "Conv2dWithPaddingDilationStrideModule_basic", - "Conv2dWithPaddingModule_basic", + "UnsafeView1DFoldModule_basic", + "UnsafeViewExpandModule_basic", + "View1DFoldModule_basic", + "ViewCollapseInferredDimModule_basic", + "ViewCollapseOnesMiddleModule_basic", + "ViewDoubleMergeStaticModule_basic", + "ViewExpandCollapseModule_basic", + "ViewExpandCollapseWithOnesModule_basic", + "ViewExpandInferredDimModule_basic", + "ViewExpandModule_basic", + "ViewExpandOnesBeforeAndAfterModule_basic", + "ViewExpandOnesMiddleModule_basic", + "ViewExpandOnesMiddleOppModule_basic", + "ViewExpandOnesModule_basic", + "ViewFiveTestStaticModule_basic", + "ViewNegativeStaticModule_basic", + "ViewNoChangeStaticModule_basic", + "ViewOffsetBackwardTestStaticModule_basic", + "ViewOffsetTestStaticModule_basic", + "ViewTwoFiveThreeStaticModule_basic", + "ViewTwoToThreeStaticModule_basic", + "ZerosModuleDefaultDtype_basic", + "ZerosModuleFalsePinMemory_basic", + "ZerosModuleFloat2D_basic", + "ZerosModuleFloat3D_basic", + "ZerosModuleInt2D_basic", + "ZerosModuleInt3D_basic", } MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | { @@ -1508,6 +1490,8 @@ "CumsumStaticNegativeDimModule_basic", "CumsumInputDtypeInt32Module_basic", "EyeStaticModule_basic", + "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", "NativeGroupNormBackwardModule_basic", "SliceWholeTensorModule_basic", "TensorFloatModule_basic", @@ -1522,8 +1506,6 @@ "NormalizeModule_basic", "ReduceFrobeniusNormKeepDimModule_basic", "ReduceFrobeniusNormModule_basic", - "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", }) - { ### Test failing in make_fx_tosa but not in tosa @@ -1553,8 +1535,13 @@ "CollapseStaticModule_basic", "CollapsePartialDynamicModule_basic", "CollapseFullDynamicModule_basic", + "SplitDimStaticModule_basic", + "SplitDimDynamicModule_basic", "PixelShuffleModuleStaticRank3Int64_basic", "PixelShuffleModuleStaticRank4Float32_basic", + "PixelShuffleModuleFullDynamic_basic", + "PixelShuffleModuleSpatiallyDynamic_basic", + "PixelShuffleModuleSpatiallyStatic_basic", "_Convolution2DAllFalseModule_basic", "_Convolution2DBenchmarkModule_basic", "_Convolution2DCudnnModule_basic", diff --git a/projects/pt1/examples/torchscript_resnet_inference.ipynb b/projects/pt1/examples/torchscript_resnet_inference.ipynb index 82258fd39278..3ab7cc64dadb 100644 --- a/projects/pt1/examples/torchscript_resnet_inference.ipynb +++ b/projects/pt1/examples/torchscript_resnet_inference.ipynb @@ -92,8 +92,8 @@ "import torchvision\n", "\n", "import torch_mlir\n", - "from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder\n", - "from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations\n", + "from torch_mlir.jit_ir_importer import ClassAnnotator, ModuleBuilder\n", + "from torch_mlir.jit_ir_importer.torchscript_annotations import extract_annotations\n", "\n", "from torch_mlir.passmanager import PassManager\n", "from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend" diff --git a/projects/pt1/python/CMakeLists.txt b/projects/pt1/python/CMakeLists.txt index 73bd93f033db..6ed43a7317c8 100644 --- a/projects/pt1/python/CMakeLists.txt +++ b/projects/pt1/python/CMakeLists.txt @@ -1,5 +1,3 @@ -include(AddMLIRPython) - # Disables generation of "version soname" (i.e. libFoo.so.), which # causes pure duplication as part of Python wheels. set(CMAKE_PLATFORM_NO_VERSIONED_SONAME ON) @@ -17,8 +15,6 @@ add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=torch_mlir.") # PyTorch ################################################################################ -option(TORCH_MLIR_USE_INSTALLED_PYTORCH "Build from local PyTorch in environment" ON) - if (NOT TORCH_MLIR_USE_INSTALLED_PYTORCH) # Source builds set(ENV{TORCH_MLIR_SRC_PYTORCH_REPO} ${TORCH_MLIR_SRC_PYTORCH_REPO}) @@ -92,9 +88,6 @@ declare_mlir_python_extension(TorchMLIRPythonExtensions.Main # Lazy Tensor Core ################################################################################ -if(TORCH_MLIR_ENABLE_LTC) - add_subdirectory(torch_mlir/csrc/base_lazy_backend) -endif() # Reference backend has a separate check for TORCH_MLIR_ENABLE_LTC, since it # generates a dummy Python library when disabled. if(NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS) @@ -106,7 +99,8 @@ endif() ################################################################################ if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER) - add_subdirectory(torch_mlir/dialects/torch/importer/jit_ir) + add_subdirectory(torch_mlir/jit_ir_importer) + add_subdirectory(torch_mlir/csrc/jit_ir_importer) add_subdirectory(torch_mlir_e2e_test) endif() diff --git a/projects/pt1/python/test/annotations-sugar.py b/projects/pt1/python/test/annotations-sugar.py index 98cbec74d1c5..e540e84b9e15 100644 --- a/projects/pt1/python/test/annotations-sugar.py +++ b/projects/pt1/python/test/annotations-sugar.py @@ -8,8 +8,8 @@ import torch from torch_mlir_e2e_test.annotations import annotate_args, export -from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator -from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations +from torch_mlir.jit_ir_importer import ClassAnnotator +from torch_mlir.jit_ir_importer.torchscript_annotations import extract_annotations class MmModule(torch.nn.Module): def __init__(self): diff --git a/projects/pt1/python/torch_mlir/__init__.py b/projects/pt1/python/torch_mlir/__init__.py index 555642ac4947..f5a4f4fdf992 100644 --- a/projects/pt1/python/torch_mlir/__init__.py +++ b/projects/pt1/python/torch_mlir/__init__.py @@ -19,13 +19,8 @@ from torch.fx.experimental.proxy_tensor import make_fx from .compiler_utils import run_pipeline_with_repro_report -from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder -from torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator import generate_library -from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import ( - TOSA_TO_LINALG_FUNC_PIPELINE, - LinalgOnTensorsTosaBackend, - ) -from ._mlir_libs._mlir.ir import Module +from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder +from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_library from .repro import reproduce from .compiler_utils import prepare_model, map_kwargs_into_args diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/CMakeLists.txt b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/CMakeLists.txt similarity index 50% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/CMakeLists.txt rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/CMakeLists.txt index 287e9a20c87b..5ae5ddf0a487 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/CMakeLists.txt +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/CMakeLists.txt @@ -1,39 +1,3 @@ -# Sharp edge: Torch extensions need to use the same pybind11 that torch -# was compiled with, or else there will be issues in cross module exception -# handling (which will abort instead of raise). We circumvent the possibility -# by forcing the torch directories first. -include_directories(BEFORE - ${TORCH_INCLUDE_DIRS} - ${CMAKE_CURRENT_SOURCE_DIR} - ${CMAKE_CURRENT_BINARY_DIR} - ${Python3_INCLUDE_DIRS} - ) -link_directories("${TORCH_INSTALL_PREFIX}/lib") - -# Static library with core functionality. -# We can't use a shared library here, due to issues with linking on macOS-arm64 (the library itself won't build) -# For details, see: https://github.com/llvm/torch-mlir/runs/7919012376 -add_library(TorchMLIRJITIRImporter STATIC - class_annotator.cpp - function_importer.cpp - node_importer.cpp - ivalue_importer.cpp - torch_to_mlir_utils.cpp - ) -target_link_libraries(TorchMLIRJITIRImporter - TorchMLIRAggregateCAPI - ${TORCH_LIBRARIES} - ) -message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS}") -set_target_properties(TorchMLIRJITIRImporter PROPERTIES - LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs" - OUTPUT_NAME lib_jit_ir_importer - PREFIX "" - SUFFIX ".a" - CXX_VISIBILITY_PRESET "default" - COMPILE_FLAGS "${TORCH_CXXFLAGS}" - ) - # Separate Pybind MODULE due to issues with a SHARED library. # https://github.com/llvm/torch-mlir/issues/1154 add_library(TorchMLIRJITIRImporterPybind MODULE @@ -62,7 +26,6 @@ if(Python3_LIBRARIES) ) endif() -message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS}") set_target_properties(TorchMLIRJITIRImporterPybind PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs" OUTPUT_NAME _jit_ir_importer diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator_pybind.cpp b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator_pybind.cpp similarity index 79% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator_pybind.cpp rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator_pybind.cpp index 7d8525209d44..c1219d48d4d4 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator_pybind.cpp +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator_pybind.cpp @@ -8,7 +8,7 @@ //===----------------------------------------------------------------------===// #include "class_annotator_pybind.h" -#include "class_annotator.h" +#include "jit_ir_importer/class_annotator.h" #include #include @@ -18,7 +18,7 @@ using namespace torch_mlir; static c10::ScalarType convertToC10ScalarType(py::object obj) { if (THPDtype_Check(obj.ptr())) { // Need reinterpret_cast, since no C++-level inheritance is involved. - THPDtype *dtype = reinterpret_cast(obj.ptr()); + THPDtype* dtype = reinterpret_cast(obj.ptr()); return dtype->scalar_type; } std::stringstream ss; @@ -48,16 +48,17 @@ static std::vector getArgAnnotations(py::list pyArgAnnotations) { return argAnnotations; } -void torch_mlir::initClassAnnotatorBindings(py::module &m) { +void torch_mlir::initClassAnnotatorBindings(py::module& m) { py::class_(m, "ClassAnnotator") .def(py::init<>()) .def("exportPath", &ClassAnnotator::exportPath) .def("exportNone", &ClassAnnotator::exportNone) - .def("annotateArgs", - [&](ClassAnnotator &cls_annotator, c10::ClassType &rootClassType, - std::vector path, py::list argAnnotations) { - cls_annotator.annotateArgs(rootClassType, path, - getArgAnnotations(argAnnotations)); - }) + .def( + "annotateArgs", + [&](ClassAnnotator& cls_annotator, c10::ClassType& rootClassType, + std::vector path, py::list argAnnotations) { + cls_annotator.annotateArgs( + rootClassType, path, getArgAnnotations(argAnnotations)); + }) .def("__repr__", &ClassAnnotator::toString); } diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator_pybind.h b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator_pybind.h similarity index 95% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator_pybind.h rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator_pybind.h index a0d1a75817ad..4eb170b8ba9a 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator_pybind.h +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator_pybind.h @@ -18,7 +18,7 @@ namespace py = pybind11; namespace torch_mlir { -void initClassAnnotatorBindings(py::module &m); +void initClassAnnotatorBindings(py::module& m); } // namespace torch_mlir #endif // TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_PYBIND_H diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/get_registered_ops.cpp b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/get_registered_ops.cpp similarity index 89% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/get_registered_ops.cpp rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/get_registered_ops.cpp index 2b90b3b65bff..a168ca1c05d3 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/get_registered_ops.cpp +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/get_registered_ops.cpp @@ -50,9 +50,9 @@ static py::list getRegisteredOps() { // since the JIT has its own dispatch mechanism that it uses to implement // "prim" ops and a handful of "aten" ops that are effectively prim ops, such // as `aten::__is__`. - for (const std::shared_ptr &op : + for (const std::shared_ptr& op : torch::jit::getAllOperators()) { - const c10::FunctionSchema &schema = op->schema(); + const c10::FunctionSchema& schema = op->schema(); py::dict record; { @@ -69,7 +69,7 @@ static py::list getRegisteredOps() { py::list arguments; py::list returns; - auto addArgument = [](py::list &container, const c10::Argument &arg) { + auto addArgument = [](py::list& container, const c10::Argument& arg) { py::dict argRecord; argRecord["name"] = arg.name(); argRecord["type"] = arg.type()->str(); @@ -87,10 +87,10 @@ static py::list getRegisteredOps() { py::dict aliasInfo; py::list before; py::list after; - for (auto &symbol : arg.alias_info()->beforeSets()) { + for (auto& symbol : arg.alias_info()->beforeSets()) { before.append(std::string(symbol.toQualString())); } - for (auto &symbol : arg.alias_info()->afterSets()) { + for (auto& symbol : arg.alias_info()->afterSets()) { after.append(std::string(symbol.toQualString())); } aliasInfo["is_write"] = arg.alias_info()->isWrite(); @@ -101,10 +101,10 @@ static py::list getRegisteredOps() { container.append(std::move(argRecord)); }; - for (auto &argument : schema.arguments()) { + for (auto& argument : schema.arguments()) { addArgument(arguments, argument); } - for (auto &returnArg : schema.returns()) { + for (auto& returnArg : schema.returns()) { addArgument(returns, returnArg); } record["arguments"] = std::move(arguments); @@ -115,6 +115,6 @@ static py::list getRegisteredOps() { return results; } -void torch_mlir::initGetRegisteredOpsBindings(py::module &m) { +void torch_mlir::initGetRegisteredOpsBindings(py::module& m) { m.def("get_registered_ops", &getRegisteredOps, kGetRegisteredOpsDocstring); } diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/get_registered_ops.h b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/get_registered_ops.h similarity index 94% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/get_registered_ops.h rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/get_registered_ops.h index ec336878c3c7..b2851e6a4208 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/get_registered_ops.h +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/get_registered_ops.h @@ -19,7 +19,7 @@ namespace torch_mlir { -void initGetRegisteredOpsBindings(py::module &m); +void initGetRegisteredOpsBindings(py::module& m); } // namespace torch_mlir diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options_pybind.cpp b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options_pybind.cpp similarity index 61% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options_pybind.cpp rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options_pybind.cpp index b072b0ed922c..94a47229dda7 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options_pybind.cpp +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options_pybind.cpp @@ -8,17 +8,19 @@ //===----------------------------------------------------------------------===// #include "import_options_pybind.h" -#include "import_options.h" +#include "jit_ir_importer/import_options.h" namespace py = pybind11; using namespace torch_mlir; -void torch_mlir::initImportOptionsBindings(py::module &m) { +void torch_mlir::initImportOptionsBindings(py::module& m) { py::class_(m, "ImportOptions") .def(py::init<>()) - .def_readwrite("assumeTensorsHaveValueSemantics", - &ImportOptions::assumeTensorsHaveValueSemantics) - .def_readwrite("ignoreExistingTensorShapesAndDtypes", - &ImportOptions::ignoreExistingTensorShapesAndDtypes); + .def_readwrite( + "assumeTensorsHaveValueSemantics", + &ImportOptions::assumeTensorsHaveValueSemantics) + .def_readwrite( + "ignoreExistingTensorShapesAndDtypes", + &ImportOptions::ignoreExistingTensorShapesAndDtypes); } diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options_pybind.h b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options_pybind.h similarity index 92% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options_pybind.h rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options_pybind.h index 6e8e1389ca3a..4ca27a218584 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options_pybind.h +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options_pybind.h @@ -13,7 +13,7 @@ #include namespace torch_mlir { -void initImportOptionsBindings(pybind11::module &m); +void initImportOptionsBindings(pybind11::module& m); } // namespace torch_mlir #endif // TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_PYBIND_H diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/init_python_bindings.cpp b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/init_python_bindings.cpp similarity index 100% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/init_python_bindings.cpp rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/init_python_bindings.cpp diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.cpp b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/module_builder.cpp similarity index 75% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.cpp rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/module_builder.cpp index ca4bd600f5ad..92f131b0d73b 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.cpp +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/module_builder.cpp @@ -9,9 +9,9 @@ #include "module_builder.h" -#include "function_importer.h" -#include "ivalue_importer.h" -#include "mlir_utils.h" +#include "jit_ir_importer/function_importer.h" +#include "jit_ir_importer/ivalue_importer.h" +#include "jit_ir_importer/mlir_utils.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/BuiltinAttributes.h" @@ -22,7 +22,7 @@ namespace py = pybind11; using namespace torch_mlir; -static py::object getMlirIrClass(const char *className) { +static py::object getMlirIrClass(const char* className) { return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr(className); } @@ -33,7 +33,7 @@ static py::object createPythonContextIfNone(py::object contextObj) { return contextObj; } -static MlirContext castPythonObjectToMlirContext(py::object &contextObj) { +static MlirContext castPythonObjectToMlirContext(py::object& contextObj) { assert(!contextObj.is_none() && "context cannot be None"); auto contextCapsule = contextObj.attr(MLIR_PYTHON_CAPI_PTR_ATTR); MlirContext context = mlirPythonCapsuleToContext(contextCapsule.ptr()); @@ -77,15 +77,15 @@ static void printDiagnostic(MlirDiagnostic diagnostic) { std::stringstream ss; ss << stringifyMlirDiagnosticSeverity(mlirDiagnosticGetSeverity(diagnostic)) << ": "; - auto stringCallback = [](MlirStringRef s, void *stringCallbackUserData) { - auto *ssp = static_cast(stringCallbackUserData); + auto stringCallback = [](MlirStringRef s, void* stringCallbackUserData) { + auto* ssp = static_cast(stringCallbackUserData); ssp->write(s.data, s.length); }; - mlirDiagnosticPrint(diagnostic, stringCallback, static_cast(&ss)); + mlirDiagnosticPrint(diagnostic, stringCallback, static_cast(&ss)); // Use pybind11's print: // https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html - py::print(ss.str(), - py::arg("file") = py::module_::import("sys").attr("stderr")); + py::print( + ss.str(), py::arg("file") = py::module_::import("sys").attr("stderr")); } // Register a diagnostic handler that will redirect output to `sys.stderr` @@ -93,7 +93,7 @@ static void printDiagnostic(MlirDiagnostic diagnostic) { // that mlir diagnostics emitted are correctly routed in Jupyter notebooks. static void registerPythonSysStderrDiagnosticHandler(MlirContext context) { auto diagnosticHandler = [](MlirDiagnostic diagnostic, - void *) -> MlirLogicalResult { + void*) -> MlirLogicalResult { printDiagnostic(diagnostic); for (int i = 0, e = mlirDiagnosticGetNumNotes(diagnostic); i != e; i++) { printDiagnostic(mlirDiagnosticGetNote(diagnostic, i)); @@ -101,7 +101,7 @@ static void registerPythonSysStderrDiagnosticHandler(MlirContext context) { return mlirLogicalResultSuccess(); }; MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler( - context, diagnosticHandler, nullptr, [](void *) { return; }); + context, diagnosticHandler, nullptr, [](void*) { return; }); // Ignore the ID. We intend to keep this handler for the entire lifetime // of this context. (void)id; @@ -123,28 +123,28 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj) terminator = mlirBlockGetFirstOperation(getBodyBlock()); } -torch::jit::StrongFunctionPtr -ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function, - py::object maybeImportOptions) { +torch::jit::StrongFunctionPtr ModuleBuilder::importFunction( + torch::jit::StrongFunctionPtr function, py::object maybeImportOptions) { ImportOptions importOptions; if (!maybeImportOptions.is_none()) { importOptions = py::cast(maybeImportOptions); } MlirBlock block = getBodyBlock(); MlirOperation terminator = this->terminator; - MlirOperation func = importJitFunctionAsFuncOp(context, function.function_, - [](int) -> MlirAttribute { return {nullptr}; }, importOptions); + MlirOperation func = importJitFunctionAsFuncOp( + context, function.function_, + [](int) -> MlirAttribute { return {nullptr}; }, importOptions); mlirBlockInsertOwnedOperationBefore(block, terminator, func); return function; } -void ModuleBuilder::importModule(torch::jit::Module jitModule, - py::object maybeClassAnnotator, - py::object maybeImportOptions) { +void ModuleBuilder::importModule( + torch::jit::Module jitModule, py::object maybeClassAnnotator, + py::object maybeImportOptions) { ClassAnnotator dummyAnnotator; - ClassAnnotator *classAnnotator = &dummyAnnotator; + ClassAnnotator* classAnnotator = &dummyAnnotator; if (!maybeClassAnnotator.is_none()) { - classAnnotator = py::cast(maybeClassAnnotator); + classAnnotator = py::cast(maybeClassAnnotator); } ImportOptions importOptions; if (!maybeImportOptions.is_none()) { @@ -168,14 +168,15 @@ void ModuleBuilder::importModule(torch::jit::Module jitModule, // precise `torch.class_type` names. // // This name is not semantically load-bearing!!! - auto &name = *jitModule.type()->name(); + auto& name = *jitModule.type()->name(); auto debugModuleNameAttr = mlirStringAttrGet( context, toMlirStringRef(name.atoms()[name.atoms().size() - 1])); - mlirOperationSetAttributeByName(mlirModuleGetOperation(module), - toMlirStringRef("torch.debug_module_name"), - debugModuleNameAttr); - importIValue(jitModule._ivalue(), mlirModuleGetBody(module), - mlirModuleGetContext(module), *classAnnotator, importOptions); + mlirOperationSetAttributeByName( + mlirModuleGetOperation(module), + toMlirStringRef("torch.debug_module_name"), debugModuleNameAttr); + importIValue( + jitModule._ivalue(), mlirModuleGetBody(module), + mlirModuleGetContext(module), *classAnnotator, importOptions); } MlirBlock ModuleBuilder::getBodyBlock() { @@ -183,14 +184,16 @@ MlirBlock ModuleBuilder::getBodyBlock() { return mlirRegionGetFirstBlock(mlirOperationGetRegion(moduleOp, 0)); } -void ModuleBuilder::bind(py::module &m) { +void ModuleBuilder::bind(py::module& m) { py::class_(m, "ModuleBuilder") .def(py::init(), py::arg("context") = py::none()) .def_property_readonly("context", &ModuleBuilder::getContextObj) .def_property_readonly("module", &ModuleBuilder::getModuleObj) - .def("import_function", &ModuleBuilder::importFunction, py::arg("function"), - py::arg("importOptions") = py::none()) - .def("import_module", &ModuleBuilder::importModule, py::arg("module"), - py::arg("classAnnotator") = py::none(), - py::arg("importOptions") = py::none()); + .def( + "import_function", &ModuleBuilder::importFunction, + py::arg("function"), py::arg("importOptions") = py::none()) + .def( + "import_module", &ModuleBuilder::importModule, py::arg("module"), + py::arg("classAnnotator") = py::none(), + py::arg("importOptions") = py::none()); } diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.h b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/module_builder.h similarity index 84% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.h rename to projects/pt1/python/torch_mlir/csrc/jit_ir_importer/module_builder.h index 08695e15faf3..cff2200d365a 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.h +++ b/projects/pt1/python/torch_mlir/csrc/jit_ir_importer/module_builder.h @@ -10,7 +10,7 @@ #ifndef TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H #define TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H -#include "class_annotator.h" +#include "jit_ir_importer/class_annotator.h" #include "mlir-c/IR.h" @@ -29,7 +29,7 @@ class ModuleBuilder { ModuleBuilder(pybind11::object contextObj); /// Creates Python bindings for the class. - static void bind(pybind11::module &m); + static void bind(pybind11::module& m); pybind11::object getContextObj() { return contextObj; } pybind11::object getModuleObj() { return moduleObj; } @@ -38,16 +38,15 @@ class ModuleBuilder { // torch.jit.ScriptFunction is the C++ type torch::jit::StrongFunctionPtr. // Just a bit of naming cruft. // Returns the same function, making it suitable as a nested decorator. - torch::jit::StrongFunctionPtr - importFunction(torch::jit::StrongFunctionPtr function, - py::object maybeImportOptions); + torch::jit::StrongFunctionPtr importFunction( + torch::jit::StrongFunctionPtr function, py::object maybeImportOptions); // Imports a torch::jit::Module into the current module, using the // annotations, if not none, provided in `maybeClassAnnotator` which should be // a ClassAnnotator. - void importModule(torch::jit::Module jitModule, - py::object maybeClassAnnotator, - py::object maybeImportOptions); + void importModule( + torch::jit::Module jitModule, py::object maybeClassAnnotator, + py::object maybeImportOptions); private: MlirBlock getBodyBlock(); diff --git a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/CMakeLists.txt b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/CMakeLists.txt index fdef27143728..1c1f2fa2a43b 100644 --- a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/CMakeLists.txt +++ b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/CMakeLists.txt @@ -1,28 +1,3 @@ -########################################################################### -# Setup PyTorch -########################################################################### - -include(TorchMLIRPyTorch) - -TorchMLIRProbeForPyTorchInstall() -if(TORCH_MLIR_USE_INSTALLED_PYTORCH) - TorchMLIRConfigurePyTorch() -else() - # Assume it is a sibling to the overall project. - set(Torch_DIR "${PROJECT_SOURCE_DIR}/../libtorch/share/cmake/Torch") - message(STATUS "Attempting to locate libtorch as a sibling to the project: ${Torch_DIR}") -endif() - -find_package(Torch 1.11 REQUIRED) - -########################################################################### -# Setup Python development -########################################################################### - -list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/externals/llvm-project/mlir/cmake/modules") -include(MLIRDetectPythonEnv) -mlir_configure_python_dev_packages() - ########################################################################### # Library definition ########################################################################### diff --git a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp index 1064a3d1e1ac..4bcb9347b5aa 100644 --- a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp +++ b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp @@ -14,12 +14,12 @@ #include #include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include #include "backend_impl.h" diff --git a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp index c575d9dd299b..f4b8cd9ba579 100644 --- a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp +++ b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp @@ -11,10 +11,10 @@ #include "torch/csrc/lazy/core/config.h" #include "torch/csrc/lazy/backend/backend_interface.h" -#include -#include -#include -#include +#include +#include +#include +#include #include #include diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/CMakeLists.txt b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/CMakeLists.txt deleted file mode 100644 index 30bb4cb3151a..000000000000 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/CMakeLists.txt +++ /dev/null @@ -1,32 +0,0 @@ -#------------------------------------------------------------------------------- -# Setup PyTorch -#------------------------------------------------------------------------------- - -include(TorchMLIRPyTorch) - -TorchMLIRProbeForPyTorchInstall() -if(TORCH_MLIR_USE_INSTALLED_PYTORCH) - TorchMLIRConfigurePyTorch() -else() - # Assume it is a sibling to the overall project. - set(Torch_DIR "${PROJECT_SOURCE_DIR}/../libtorch/share/cmake/Torch") - message(STATUS "Attempting to locate libtorch as a sibling to the project: ${Torch_DIR}") -endif() - -find_package(Torch 1.11 REQUIRED) - -message(STATUS "libtorch_python CXXFLAGS is ...${TORCH_CXXFLAGS}") -#------------------------------------------------------------------------------- -# Subdirectories -#------------------------------------------------------------------------------- - -add_subdirectory(csrc) - -## Declare the sources of the Python module. - -declare_mlir_python_sources(TorchMLIRPythonSources.JitIRImporter - ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" - ADD_TO_PARENT TorchMLIRPythonSources - SOURCES_GLOB - dialects/torch/importer/jit_ir/*.py -) diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/__init__.py b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/CMakeLists.txt b/projects/pt1/python/torch_mlir/jit_ir_importer/CMakeLists.txt new file mode 100644 index 000000000000..c2883b3dca84 --- /dev/null +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/CMakeLists.txt @@ -0,0 +1,12 @@ +#------------------------------------------------------------------------------- +# Subdirectories +#------------------------------------------------------------------------------- + +## Declare the sources of the Python module. + +declare_mlir_python_sources(TorchMLIRPythonSources.JitIRImporter + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TorchMLIRPythonSources + SOURCES_GLOB + jit_ir_importer/*.py +) diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/__init__.py b/projects/pt1/python/torch_mlir/jit_ir_importer/__init__.py similarity index 75% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/__init__.py rename to projects/pt1/python/torch_mlir/jit_ir_importer/__init__.py index ead98dd5c6db..b5a49561ade0 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/__init__.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/__init__.py @@ -11,8 +11,11 @@ # Our native extension is not self-contained. It references libraries which # must come in via the above first. -from ....._mlir_libs._jit_ir_importer import * +from .._mlir_libs._jit_ir_importer import * +# Ensure that the torch dialect has been loaded as it registers passes +# and other things the jit_ir_importer needs. +from ..dialects import torch as _unused_torch_dialect __all__ = [ "debug_trace_to_stderr", diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/__init__.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/__init__.py similarity index 100% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/__init__.py rename to projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/__init__.py diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py similarity index 97% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py rename to projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 86dd9d44c5b6..c18817070a2d 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -200,9 +200,15 @@ def aten〇log_softmax〇int〡shape(self: List[int], dim: int, dtype: Optional[ def aten〇clamp〡shape(self: List[int], min: Optional[float] = None, max: Optional[float] = None) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇clamp〇Tensor〡shape(self: List[int], min: Optional[List[int]] = None, max: Optional[List[int]] = None) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇clamp_min〡shape(self: List[int], min: float) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇clamp_min〇Tensor〡shape(self: List[int], min: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇clamp_max〡shape(self: List[int], max: float) -> List[int]: return upstream_shape_functions.unary(self) @@ -246,6 +252,24 @@ def prims〇collapse〡shape(a: List[int], start: int, end: int) -> List[int]: return collapsed +def prims〇split_dim〡shape(a: List[int], dim: int, outer_length: int) -> List[int]: + assert dim >=0, "'dim' must be non-negative" + assert dim < len(a), "'dim' must be less than the rank of the tensor" + assert outer_length > 0, "'outer_length' must be positive" + assert a[dim] % outer_length == 0, "'outer_length' must divide the size of the dimension, a[dim]" + + split: List[int] = [] + for i in range(dim): + split.append(a[i]) + + split.append(outer_length) + split.append(a[dim] // outer_length) + + for i in range(dim + 1, len(a)): + split.append(a[i]) + + return split + def aten〇to〇dtype〡shape(self: List[int], dtype: int, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> List[int]: return upstream_shape_functions.unary(self) @@ -433,6 +457,10 @@ def aten〇std〇correction〡shape(self: List[int], dim: Optional[List[int]] = def aten〇argmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]: return upstream_shape_functions.argmax(self, dim, keepdim) +def aten〇argmin〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]: + # There is no shape function for argmin in pytorch, but the one for argmax does exactly what is needed here. + return upstream_shape_functions.argmax(self, dim, keepdim) + # TODO: The result shape when num_classes=-1 depends on the runtime values of the input tensor, # making it impossible to add support for it using the current design of the shape library. def aten〇one_hot〡shape(self: List[int], num_classes: int = -1) -> List[int]: @@ -446,6 +474,10 @@ def aten〇max〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) - reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim) return reduced_shape, reduced_shape +def aten〇min〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> Tuple[List[int], List[int]]: + reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim) + return reduced_shape, reduced_shape + def aten〇amax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) @@ -915,6 +947,9 @@ def aten〇bitwise_xor〇Tensor〡shape(self: List[int], other: List[int]) -> Li def aten〇bitwise_right_shift〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) +def aten〇bitwise_left_shift〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.broadcast(self, other) + def aten〇bitwise_not〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -1511,6 +1546,11 @@ def aten〇view_as_real〡dtype(self_rank_dtype: Tuple[int, int]) -> int: else: return torch.double +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, outer_length=1)) +def prims〇split_dim〡dtype(a_rank_dtype: Tuple[int, int], dim: int, outer_length: int) -> int: + _, a_dtype = a_rank_dtype + return a_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -1720,6 +1760,15 @@ def aten〇clamp_min〡dtype(self_rank_dtype: Tuple[int, int], min: Union[int, f return torch.int64 return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2)) +def aten〇clamp_min〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], min_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + min_rank, min_dtype = min_rank_dtype + ranks: List[Optional[int]] = [self_rank, min_rank] + dtypes = [self_dtype, min_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + return promoted_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, min=-1, max=1)) def aten〇clamp〡dtype(self_rank_dtype: Tuple[int, int], min: Optional[Union[int, float, complex]] = None, max: Optional[Union[int, float, complex]] = None) -> int: self_rank, self_dtype = self_rank_dtype @@ -1727,6 +1776,23 @@ def aten〇clamp〡dtype(self_rank_dtype: Tuple[int, int], min: Optional[Union[i return torch.int64 return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=3)) +def aten〇clamp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], min_rank_dtype: Optional[Tuple[int, int]] = None, max_rank_dtype: Optional[Tuple[int, int]] = None) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank] + dtypes = [self_dtype] + if min_rank_dtype is not None: + min_rank, min_dtype = min_rank_dtype + ranks.append(min_rank) + dtypes.append(min_dtype) + if max_rank_dtype is not None: + max_rank, max_dtype = max_rank_dtype + ranks.append(max_rank) + dtypes.append(max_dtype) + if len(ranks) > 1: + return promote_dtypes(ranks, dtypes) + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇clone〡dtype(self_rank_dtype: Tuple[int, int], memory_format: Optional[int] = None) -> int: self_rank, self_dtype = self_rank_dtype @@ -2479,6 +2545,14 @@ def aten〇bitwise_right_shift〇Tensor〡dtype(self_rank_dtype: Tuple[int, int] dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function(_check_two_tensor_op()) +def aten〇bitwise_left_shift〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + @check_dtype_function( _check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) + # Different width @@ -3264,7 +3338,10 @@ def aten〇mean〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[Li @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇argmax〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None, keepdim: bool = False) -> int: - self_rank, self_dtype = self_rank_dtype + return torch.int64 + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇argmin〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None, keepdim: bool = False) -> int: return torch.int64 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) @@ -3300,6 +3377,10 @@ def aten〇amax〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int] = (), k def aten〇max〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]: return aten〇max〡dtype(self_rank_dtype), torch.int64 +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) +def aten〇min〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]: + return aten〇min〡dtype(self_rank_dtype), torch.int64 + @check_dtype_function( _check_tensors_with_the_same_dtype( num_of_tensors=1, @@ -3683,6 +3764,19 @@ def aten〇cat〡dtype(tensors_rank_dtype: List[Tuple[int, int]], dim: int = 0) dtypes.append(tensor_dtype) return promote_dtypes(ranks, dtypes) +@check_dtype_function( + [Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.int32)]),]) +def aten〇einsum〡dtype(equation: str, tensors_rank_dtype: List[Tuple[int, int]], path: Optional[List[int]] = None) -> int: + ranks: List[Optional[int]] = [] + dtypes: List[int] = [] + assert len(tensors_rank_dtype) != 0 + for tensor_rank_dtype in tensors_rank_dtype: + tensor_rank, tensor_dtype = tensor_rank_dtype + ranks.append(tensor_rank) + dtypes.append(tensor_dtype) + return promote_dtypes(ranks, dtypes) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇_shape_as_tensor〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.int64 diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/library_generator.py similarity index 99% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py rename to projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/library_generator.py index 74eb520e22d4..6cd19643a5f0 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/library_generator.py @@ -10,7 +10,7 @@ import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder from torch_mlir.passmanager import PassManager from .registry import Registry diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/registry.py similarity index 100% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py rename to projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/registry.py diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/testing_framework.py similarity index 100% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py rename to projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/testing_framework.py diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py similarity index 99% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py rename to projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 53e0c3e416f7..fab101525bd3 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -102,7 +102,7 @@ def _get_main_module_name() -> str: // This file is automatically generated. Please do not edit. // Generated via: // ``` -// python -m {_get_main_module_name()} +// build_tools/update_torch_ods.sh // ``` // //===----------------------------------------------------------------------===// @@ -274,10 +274,10 @@ def emit_with_mutating_variants(key, **kwargs): "aten::exp : (Tensor) -> (Tensor)", "aten::expm1 : (Tensor) -> (Tensor)", "aten::cos : (Tensor) -> (Tensor)", + "aten::acos : (Tensor) -> (Tensor)", "aten::atan : (Tensor) -> (Tensor)", "aten::atan2 : (Tensor, Tensor) -> (Tensor)", "aten::asin : (Tensor) -> (Tensor)", - "aten::acos : (Tensor) -> (Tensor)", "aten::neg : (Tensor) -> (Tensor)", "aten::ceil : (Tensor) -> (Tensor)", "aten::bitwise_not : (Tensor) -> (Tensor)", @@ -302,7 +302,6 @@ def emit_with_mutating_variants(key, **kwargs): "aten::le.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)", - "aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", "aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)", "aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)", "aten::clamp_min : (Tensor, Scalar) -> (Tensor)", @@ -320,6 +319,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::bitwise_xor.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::bitwise_left_shift.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)", "aten::square : (Tensor) -> (Tensor)", @@ -339,6 +339,7 @@ def emit_with_mutating_variants(key, **kwargs): emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)") emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)") @@ -567,6 +568,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)") emit("aten::argmin : (Tensor, int?, bool) -> (Tensor)") emit("aten::one_hot : (Tensor, int) -> (Tensor)") + emit("aten::einsum : (str, Tensor[], int[]?) -> (Tensor)") emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)") emit("aten::clone : (Tensor, int?) -> (Tensor)") emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)") @@ -693,7 +695,7 @@ def emit_with_mutating_variants(key, **kwargs): # List ops. emit("aten::cat : (Tensor[], int) -> (Tensor)", has_folder=True) - emit("aten::stack : (Tensor[], int) -> (Tensor)", has_folder=True) + emit("aten::stack : (Tensor[], int) -> (Tensor)") emit("aten::append.t : (t[], t) -> (t[])") emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True) emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True) @@ -829,6 +831,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("prims::var : (Tensor, int[]?, float, int?) -> (Tensor)") emit("prims::sqrt : (Tensor) -> (Tensor)") emit("prims::collapse : (Tensor, int, int) -> (Tensor)") + emit("prims::split_dim : (Tensor, int, int) -> (Tensor)") emit("prims::squeeze : (Tensor, int[]) -> (Tensor)") emit("prims::sum : (Tensor, int[]?, int?) -> (Tensor)") emit("prims::view_of : (Tensor) -> (Tensor)", has_folder=True) diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/utils.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/utils.py similarity index 100% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/utils.py rename to projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/utils.py diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/torchscript_annotations.py b/projects/pt1/python/torch_mlir/jit_ir_importer/torchscript_annotations.py similarity index 97% rename from projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/torchscript_annotations.py rename to projects/pt1/python/torch_mlir/jit_ir_importer/torchscript_annotations.py index d495dda4836f..a6541b6503b1 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/torchscript_annotations.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/torchscript_annotations.py @@ -8,7 +8,7 @@ import torch import torch_mlir -from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator +from torch_mlir.jit_ir_importer import ClassAnnotator # Decorators diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py index 1c37ac38aa22..2d7147955053 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -17,6 +17,7 @@ "RepeatInterleaveModule_basic", "Im2ColModule_basic", "ElementwiseClampIntModule_basic", + "ReduceMinAlongDimUnsignedInt_basic", "ElementwiseToDtypeI64ToUI8Module_basic", } @@ -42,7 +43,6 @@ def register_all_tests(): from . import type_conversion from . import backprop from . import reduction - from . import argmax from . import matmul from . import reshape_like from . import scalar diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/argmax.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/argmax.py deleted file mode 100644 index 098ed508b63c..000000000000 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/argmax.py +++ /dev/null @@ -1,65 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import torch - -from torch_mlir_e2e_test.framework import TestUtils -from torch_mlir_e2e_test.registry import register_test_case -from torch_mlir_e2e_test.annotations import annotate_args, export - -# ============================================================================== - -class ArgmaxModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) - - def forward(self, a): - return torch.argmax(a) - - -@register_test_case(module_factory=lambda: ArgmaxModule()) -def ArgmaxModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4)) - -# ============================================================================== - -class ArgmaxWithDimModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) - def forward(self, a): - return torch.argmax(a, dim=1) - -@register_test_case(module_factory=lambda: ArgmaxWithDimModule()) -def ArgmaxModule_with_dim(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5)) - -# ============================================================================== - -class ArgmaxKeepDimsModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) - def forward(self, a): - return torch.argmax(a, 0, True) - -@register_test_case(module_factory=lambda: ArgmaxKeepDimsModule()) -def ArgmaxModule_keepDim(module, tu: TestUtils): - module.forward(tu.rand(4, 6)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 64b3d0bbdb32..0d371fe37008 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -686,7 +686,56 @@ def forward(self, x): def PixelShuffleModuleStaticRank3Int64_basic(module, tu: TestUtils): module.forward(tu.randint(12, 2, 3, low = 0, high = 100)) +# ============================================================================== + + +class PixelShuffleModuleFullDynamic(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1,-1,-1,-1], torch.int64, True)]) + def forward(self, x): + return torch.ops.aten.pixel_shuffle(x, 2) + +@register_test_case(module_factory=lambda: PixelShuffleModuleFullDynamic()) +def PixelShuffleModuleFullDynamic_basic(module, tu: TestUtils): + module.forward(tu.randint(1,8,3,3, low = 0, high = 100)) + +# ============================================================================== + +class PixelShuffleModuleSpatiallyDynamic(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([2,1,8,-1,-1], torch.int64, True)]) + def forward(self, x): + return torch.ops.aten.pixel_shuffle(x, 2) + +@register_test_case(module_factory=lambda: PixelShuffleModuleSpatiallyDynamic()) +def PixelShuffleModuleSpatiallyDynamic_basic(module, tu: TestUtils): + module.forward(tu.randint(2,1,8,2,3, low = 0, high = 100)) + + +# ============================================================================== + +class PixelShuffleModuleSpatiallyStatic(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1,-1,-1,3,1], torch.int64, True)]) + def forward(self, x): + return torch.ops.aten.pixel_shuffle(x, 2) + +@register_test_case(module_factory=lambda: PixelShuffleModuleSpatiallyStatic()) +def PixelShuffleModuleSpatiallyStatic_basic(module, tu: TestUtils): + module.forward(tu.randint(1,2,12,3,1, low = 0, high = 100)) + + +# ============================================================================== class TensorsConcatModule(torch.nn.Module): @@ -861,6 +910,28 @@ def TensorsStackModule_basic(module, tu: TestUtils): # ============================================================================== +class TensorsStackSingleElementListModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.stack([x], dim=1) + + +@register_test_case(module_factory=lambda: TensorsStackSingleElementListModule()) +def TensorsStackSingleElementListModule_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 32)) + + +# ============================================================================== + + class TensorsStackNegativeDimModule(torch.nn.Module): def __init__(self): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py index 66bab6ce01fc..a3cf7d525251 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py @@ -451,6 +451,25 @@ def EmptyModule_int(module, tu: TestUtils): module.forward() +class EmptyUInt8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + empty = torch.ops.aten.empty([1], dtype=torch.uint8) + return torch.ops.aten.zeros_like(empty).to(torch.int8) + + +@register_test_case(module_factory=lambda: EmptyUInt8Module()) +def EmptyModule_uint8(module, tu: TestUtils): + module.forward() + + class EmptyFloatModule(torch.nn.Module): def __init__(self): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 9446c81333cd..d79706dfc9dd 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -894,6 +894,106 @@ def ElementwiseClampMaxModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseClampTensorFloatModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([], torch.float32, True), + ([], torch.float32, True), + ]) + def forward(self, x, min, max): + min_clamp = torch.clamp(x, min) + max_clamp = torch.clamp(x, max=max) + both_clamp = torch.clamp(x, min=min, max=max) + return min_clamp, max_clamp, both_clamp + + +@register_test_case(module_factory=lambda: ElementwiseClampTensorFloatModule()) +def ElementwiseClampTensorFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 5, low=-10, high=10), torch.tensor([-5.0]), torch.tensor([5.0])) + + +# ============================================================================== + + +class ElementwiseClampTensorIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ([], torch.int64, True), + ([], torch.int64, True), + ]) + def forward(self, x, min, max): + min_clamp = torch.clamp(x, min) + max_clamp = torch.clamp(x, max=max) + both_clamp = torch.clamp(x, min=min, max=max) + return min_clamp, max_clamp, both_clamp + + +@register_test_case(module_factory=lambda: ElementwiseClampTensorIntModule()) +def ElementwiseClampTensorIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 5, low=-10, high=10), torch.tensor([-5]), torch.tensor([5])) + + +# ============================================================================== + + +class ElementwiseClampMinTensorFloatModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([], torch.float32, True), + ]) + def forward(self, x, min): + return torch.ops.aten.clamp_min(x, min=min) + + +@register_test_case(module_factory=lambda: ElementwiseClampMinTensorFloatModule()) +def ElementwiseClampMinTensorFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 5, low=-10, high=10), torch.tensor([-5.0])) + + +# ============================================================================== + + +class ElementwiseClampMinTensorIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ([], torch.int64, True), + ]) + def forward(self, x, min): + return torch.ops.aten.clamp_min(x, min=min) + + +@register_test_case(module_factory=lambda: ElementwiseClampMinTensorIntModule()) +def ElementwiseClampMinTensorIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 5, low=-10, high=10), torch.tensor([-5])) + + +# ============================================================================== + + class RsubFloatModule(torch.nn.Module): def __init__(self): @@ -2938,10 +3038,50 @@ def forward(self, a): def ElementwiseCosIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) +# ============================================================================== + + +class ElementwiseAcosModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.acos(a) + + +@register_test_case(module_factory=lambda: ElementwiseAcosModule()) +def ElementwiseAcosModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) # ============================================================================== +class ElementwiseAcosIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.acos(a) + + +@register_test_case(module_factory=lambda: ElementwiseAcosIntModule()) +def ElementwiseAcosIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + +# ============================================================================== + class ElementwiseNegModule(torch.nn.Module): def __init__(self): @@ -3389,6 +3529,52 @@ def AtenTriuWithPosDiagonalModule_basic(module, tu: TestUtils): # ============================================================================== +class TriuModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([4,5], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.triu(x, 1) + + +@register_test_case(module_factory=lambda: TriuModule()) +def TriuModule_basic(module, tu: TestUtils): + x=torch.tensor([[ 0.5876, -0.0794, -1.8373, 0.6654, 0.2], + [-0.2447, 0.9556, -1.2919, 1.3378, 0.3], + [ 0.4333, 0.3146, 0.6576, -1.0432, 0.4], + [-0.9888, torch.nan, torch.inf, -torch.inf, 0.5]]) + module.forward(x) + + +# ============================================================================== + + +class TriuBroadcastModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3,4,5,6], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.triu(x, 2) + + +@register_test_case(module_factory=lambda: TriuBroadcastModule()) +def TriuBroadcastModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3,4,5,6)) + + +# ============================================================================== + + class AtenTriuWithNegDiagonalModule(torch.nn.Module): def __init__(self): @@ -3787,6 +3973,69 @@ def ElementwiseBitwiseRightShiftInt8Module_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseBitwiseLeftShiftInt64Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.int64, True), + ]) + def forward(self, lhs, rhs): + return torch.bitwise_left_shift(lhs, rhs) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseLeftShiftInt64Module()) +def ElementwiseBitwiseLeftShiftInt64Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-1000, high=1000), tu.randint(3, 4, low=0, high=64)) + + +class ElementwiseBitwiseLeftShiftInt32Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, 4], torch.int32, True), + ([-1, 1], torch.int32, True), + ]) + def forward(self, lhs, rhs): + return torch.bitwise_left_shift(lhs, rhs) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseLeftShiftInt32Module()) +def ElementwiseBitwiseLeftShiftInt32Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int32), tu.randint(3, 1, low=0, high=32).to(torch.int32)) + + +class ElementwiseBitwiseLeftShiftInt8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int8, True), + ([-1, -1], torch.int8, True), + ]) + def forward(self, lhs, rhs): + return torch.bitwise_left_shift(lhs, rhs) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseLeftShiftInt8Module()) +def ElementwiseBitwiseLeftShiftInt8Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-100, high=100).to(torch.int8), tu.randint(3, 4, low=0, high=8).to(torch.int8)) + + +# ============================================================================== + + class ElementwiseBitwiseAndScalarInt64Module(torch.nn.Module): def __init__(self): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py index 57a549309c4d..ac04eeb41109 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py @@ -160,7 +160,9 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: ElementwiseGeFloatTensorModule()) def ElementwiseGeFloatTensorModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 5), tu.rand(5)) + module.forward( + torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32), + torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32)) # ============================================================================== @@ -200,7 +202,9 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: ElementwiseGtFloatTensorModule()) def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 5), tu.rand(5)) + module.forward( + torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32), + torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32)) # ============================================================================== @@ -378,6 +382,28 @@ def ElementwiseLeFloatTensorModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseLeFloatTensorNanModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.float32, True), + ]) + def forward(self, x, y): + return torch.le(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseLeFloatTensorNanModule()) +def ElementwiseLeFloatTensorNanModule_basic(module, tu: TestUtils): + module.forward( + torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32), + torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32)) + +# ============================================================================== + class ElementwiseLeIntTensorModule(torch.nn.Module): def __init__(self): super().__init__() @@ -414,7 +440,9 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: ElementwiseLtFloatTensorModule()) def ElementwiseLtFloatTensorModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 5), tu.rand(5)) + module.forward( + torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32), + torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32)) # ============================================================================== diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py index e40086bb7188..e59279ab57f7 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py @@ -225,4 +225,40 @@ def forward(self, m, v): @register_test_case(module_factory=lambda: Mv()) def Mv_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 2), tu.rand(2)) \ No newline at end of file + module.forward(tu.rand(2, 2), tu.rand(2)) + +# ============================================================================== + +class AtenMmFloatTypes(torch.nn.Module): + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ]) + def forward(self, a, b): + return torch.ops.aten.mm(a, b) + + +@register_test_case(module_factory=lambda: AtenMmFloatTypes()) +def AtenMmFloatTypes_basic(module, tu: TestUtils): + module.forward(tu.rand(8, 8), tu.rand(8, 8)) + +# ============================================================================== + +class AtenMmIntTypes(torch.nn.Module): + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.int64, True), + ]) + def forward(self, a, b): + return torch.ops.aten.mm(a, b) + + +@register_test_case(module_factory=lambda: AtenMmIntTypes()) +def AtenMmIntTypes_basic(module, tu: TestUtils): + module.forward(tu.randint(16, 4, high=100), tu.randint(4, 16, high=100)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 06159324b304..585a68e55af4 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -354,6 +354,117 @@ def ReduceMaxAlongDim_basic(module, tu: TestUtils): # ============================================================================== +class ReduceMinAlongDim(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float64, True), + ]) + def forward(self, a): + return torch.ops.aten.min(a, 1)[0] + + +@register_test_case(module_factory=lambda: ReduceMinAlongDim()) +def ReduceMinAlongDim_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5).to(torch.float64)) + +class ReduceMinAlongDimSignedInt(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.int64, True), + ]) + def forward(self, a): + return torch.ops.aten.min(a, 1) + + +@register_test_case(module_factory=lambda: ReduceMinAlongDimSignedInt()) +def ReduceMinAlongDimSignedInt_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, 5, low=-100, high=100)) + +# ============================================================================== + +class ReduceMinAlongDimUnsignedInt(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.uint8, True), + ]) + def forward(self, a): + return torch.ops.aten.min(a, 1) + + +@register_test_case(module_factory=lambda: ReduceMinAlongDimUnsignedInt()) +def ReduceMinAlongDimUnsignedInt_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, 5, low=-100, high=100).to(torch.uint8)) + +# ============================================================================== + +class ReduceMinAlongDimNegative(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float64, True), + ]) + def forward(self, a): + return torch.ops.aten.min(a, 1)[0] + + +@register_test_case(module_factory=lambda: ReduceMinAlongDimNegative()) +def ReduceMinAlongDimNegative_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5, low=-10, high=10).to(torch.float64)) + +# ============================================================================== + +class ReduceMinKeepDim(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float64, True), + ]) + def forward(self, a): + return torch.ops.aten.min(a, 1, keepdim=True)[1] + + +@register_test_case(module_factory=lambda: ReduceMinKeepDim()) +def ReduceMinKeepDim_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5).to(torch.float64)) + +# ============================================================================== + +class ReduceMinKeepDimReturnBoth(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.min(a, 1, keepdim=True) + +@register_test_case(module_factory=lambda: ReduceMinKeepDimReturnBoth()) +def ReduceMinKeepDimReturnBoth_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5, low=-10, high=-5)) + +# ============================================================================== + class ReduceMaxAlongDimSignedInt(torch.nn.Module): def __init__(self): super().__init__() @@ -663,6 +774,171 @@ def ReduceMinUnsignedIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, 5, high=100)) # ============================================================================== + +class ArgminModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + + def forward(self, a): + return torch.ops.aten.argmin(a) + + +@register_test_case(module_factory=lambda: ArgminModule()) +def ArgminModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + +# ============================================================================== + +class ArgminIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ]) + + def forward(self, a): + return torch.ops.aten.argmin(a) + + +@register_test_case(module_factory=lambda: ArgminIntModule()) +def ArgminIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-100, high=100)) + +@register_test_case(module_factory=lambda: ArgminIntModule()) +def ArgminIntModule_multiple_mins(module, tu: TestUtils): + # To cover the special case that the minimal value occurs more than once. + # The pytorch convention is here to consider the first occurence as the argmin. + module.forward(torch.full((3,4), tu.randint(1).item(), dtype=torch.int64)) + +# ============================================================================== + +class ArgminWithDimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.argmin(a, dim=1) + +@register_test_case(module_factory=lambda: ArgminWithDimModule()) +def ArgminModule_with_dim(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + +class ArgminKeepDimsModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.argmin(a, 0, True) + +@register_test_case(module_factory=lambda: ArgminKeepDimsModule()) +def ArgminModule_keepDim(module, tu: TestUtils): + module.forward(tu.rand(4, 6)) + +# ============================================================================== + +class ArgmaxModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + + def forward(self, a): + return torch.ops.aten.argmax(a) + + +@register_test_case(module_factory=lambda: ArgmaxModule()) +def ArgmaxModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + +# ============================================================================== + +class ArgmaxIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ]) + + def forward(self, a): + return torch.ops.aten.argmax(a) + + +@register_test_case(module_factory=lambda: ArgmaxIntModule()) +def ArgmaxIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-100, high=100)) + +@register_test_case(module_factory=lambda: ArgmaxIntModule()) +def ArgmaxIntModule_multiple_maxs(module, tu: TestUtils): + # To cover the special case that the maximal value occurs more than once. + # The pytorch convention is here to consider the first occurence as the argmax. + module.forward(torch.full((3,4), tu.randint(1).item(), dtype=torch.int64)) + +# ============================================================================== + +class ArgmaxWithDimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.argmax(a, dim=1) + +@register_test_case(module_factory=lambda: ArgmaxWithDimModule()) +def ArgmaxModule_with_dim(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + +class ArgmaxKeepDimsModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.argmax(a, 0, True) + +@register_test_case(module_factory=lambda: ArgmaxKeepDimsModule()) +def ArgmaxModule_keepDim(module, tu: TestUtils): + module.forward(tu.rand(4, 6)) + +# ============================================================================== + class ReduceL1NormModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py index a0ee6221b9d1..a73435c3c1ad 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -122,6 +122,46 @@ def forward(self, a): def ViewDynamicExpandModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 30, 384)) + +# ============================================================================== + + +class SplitDimStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([12], torch.float32, True)]) + + def forward(self, a): + return torch.ops.prims.split_dim(a, 0, 4) + +@register_test_case( + module_factory=lambda: SplitDimStaticModule()) +def SplitDimStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(12)) + +class SplitDimDynamicModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True)]) + + def forward(self, a): + return torch.ops.prims.split_dim(a, 0, 3) + +@register_test_case( + module_factory=lambda: SplitDimDynamicModule()) +def SplitDimDynamicModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6,5)) + + + # ============================================================================== # class CollapseAllDimensionsModule(torch.nn.Module): @@ -1004,3 +1044,59 @@ def forward(self, inputs): @register_test_case(module_factory=lambda: UnflattenIntNegativeOneSizeStaticModule()) def UnflattenIntNegativeOneSizeStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 12, 3)) + +# ============================================================================== + +class EinsumStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 2, 4], torch.float32, True), + ([5, 4, 6], torch.float32, True), + ([3, 7, 6], torch.float32, True), + ]) + def forward(self, tensor1, tensor2, tensor3): + return torch.ops.aten.einsum('bqe,ked,btd->bqtk', [tensor1, tensor2, tensor3]) + +@register_test_case(module_factory=lambda: EinsumStaticModule()) +def EinsumStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 2, 4), tu.rand(5, 4, 6), tu.rand(3, 7, 6)) + + +class EinsumStaticFourDimensionModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4, 5, 6], torch.float32, True), + ([3, 7, 5, 6], torch.float32, True), + ]) + def forward(self, tensor1, tensor2): + return torch.ops.aten.einsum('blhd,bshd->blhs', [tensor1, tensor2]) + +@register_test_case(module_factory=lambda: EinsumStaticFourDimensionModule()) +def EinsumStaticFourDimensionModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5, 6), tu.rand(3, 7, 5, 6)) + + +class EinsumStaticContractRhsModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4, 5], torch.float32, True), + ([4, 5], torch.float32, True), + ]) + def forward(self, tensor1, tensor2): + return torch.ops.aten.einsum('abc,bc->a', [tensor1, tensor2]) + +@register_test_case(module_factory=lambda: EinsumStaticContractRhsModule()) +def EinsumStaticContractRhsModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5), tu.rand(4, 5)) \ No newline at end of file diff --git a/projects/pt1/test/lit.cfg.py b/projects/pt1/test/lit.cfg.py index a9753bf22719..31e3ee388f34 100644 --- a/projects/pt1/test/lit.cfg.py +++ b/projects/pt1/test/lit.cfg.py @@ -19,7 +19,7 @@ # Configuration file for the 'lit' test runner. # name: The name of this test suite. -config.name = 'TORCH_MLIR' +config.name = 'TORCH_MLIR_PT1' config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-error.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-error.py index 7c448f6e3bbf..26eaa5bd0cb1 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-error.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-error.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder +from torch_mlir.jit_ir_importer import ClassAnnotator, ModuleBuilder # RUN: %PYTHON %s | FileCheck %s mb = ModuleBuilder() diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-tensor-type-bound.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-tensor-type-bound.py index e8bcd4864f10..6cc2d57b1caa 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-tensor-type-bound.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-tensor-type-bound.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder +from torch_mlir.jit_ir_importer import ClassAnnotator, ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s mb = ModuleBuilder() diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/class-annotator-repr.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/class-annotator-repr.py index ce235a6bf03b..3a2ed4319d24 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/class-annotator-repr.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/class-annotator-repr.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder +from torch_mlir.jit_ir_importer import ClassAnnotator, ModuleBuilder # RUN: %PYTHON %s | FileCheck %s mb = ModuleBuilder() diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export-error.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export-error.py index cc4b5656b13a..2a0806f6fff2 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export-error.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export-error.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder +from torch_mlir.jit_ir_importer import ClassAnnotator, ModuleBuilder # RUN: %PYTHON %s | FileCheck %s mb = ModuleBuilder() diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export-recursive.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export-recursive.py index cc2963d46782..79b4dccd208e 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export-recursive.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export-recursive.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder +from torch_mlir.jit_ir_importer import ClassAnnotator, ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s mb = ModuleBuilder() diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export.py index 37b5d48ad52f..433f8249b1e6 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder +from torch_mlir.jit_ir_importer import ClassAnnotator, ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s mb = ModuleBuilder() diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py index f4ad4dd3a3c4..399b45f73353 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/dict.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/dict.py index 0a9e7f9265fd..117b0cff9586 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/dict.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/dict.py @@ -5,7 +5,7 @@ from typing import Dict, Optional import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/functions-that-call-methods.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/functions-that-call-methods.py index ade43aca0ade..318e099758c6 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/functions-that-call-methods.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/functions-that-call-methods.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/functions.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/functions.py index 484260617575..ee22a495efa8 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/functions.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/functions.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/list.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/list.py index 2e8765be40a2..0c1b8f2ffddc 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/list.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/list.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods-derefine.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods-derefine.py index 6a941330d039..fee1b2922f0c 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods-derefine.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods-derefine.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods-locations.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods-locations.py index 7eb98beb9c1d..5d38d6e3a111 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods-locations.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods-locations.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods.py index fc246c458e84..0143012bf2b0 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/methods.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-error-submodule.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-error-submodule.py index 9bd66c97c125..eae86ec1c94b 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-error-submodule.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-error-submodule.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: not %PYTHON %s 2>&1 | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-error.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-error.py index a3ce3440c88e..968509accea0 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-error.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-error.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: not %PYTHON %s 2>&1 | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-torch-bug.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-torch-bug.py index 25d65101486b..4c323ec01e41 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-torch-bug.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-torch-bug.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity.py index 253bdfcec3e7..0f6516a2734c 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/prim.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/prim.py index 55fed3299e96..e48c327ed2f9 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/prim.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/prim.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/primitives.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/primitives.py index 3bcfb07173f9..3cb8cf992d33 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/primitives.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/primitives.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/quantization.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/quantization.py index f05cf434f837..d77b98323e27 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/quantization.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/quantization.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # UNSUPPORTED: system-darwin # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/strings.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/strings.py index d7d94bd9031a..b65d6f5ca038 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/strings.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/strings.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/submodules-select.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/submodules-select.py index b0834691e746..5b2cf04b5545 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/submodules-select.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/submodules-select.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/submodules.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/submodules.py index 92333d20e1db..d9983628d92f 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/submodules.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/submodules.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/tensors-value-semantics.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/tensors-value-semantics.py index e57c20fe59ee..36dfa32f0360 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/tensors-value-semantics.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/tensors-value-semantics.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder +from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/tensors.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/tensors.py index 831c619adc58..31a89e3e1e46 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/tensors.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/tensors.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/tuple.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/tuple.py index 3b0bf2d4ea69..7bed706ac600 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/tuple.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/tuple.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/classes.py b/projects/pt1/test/python/importer/jit_ir/node_import/classes.py index 511aac690277..09e2b1b0b4ac 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/classes.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/classes.py @@ -6,7 +6,7 @@ import torch from torch._C import CompilationUnit -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder import typing diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py b/projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py index f7b441a12da0..bb6ab4ce4dae 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py @@ -3,7 +3,7 @@ # See LICENSE.pytorch for license information. import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/dict.py b/projects/pt1/test/python/importer/jit_ir/node_import/dict.py index ed4371bb0147..0060357b4fca 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/dict.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/dict.py @@ -3,7 +3,7 @@ # See LICENSE.pytorch for license information. import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder import collections from typing import Tuple, Optional, List, NamedTuple, Dict diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/elif.py b/projects/pt1/test/python/importer/jit_ir/node_import/elif.py index 3a9d3a3211e2..71853b0c0b04 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/elif.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/elif.py @@ -3,7 +3,7 @@ # See LICENSE.pytorch for license information. import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/errors.py b/projects/pt1/test/python/importer/jit_ir/node_import/errors.py index be0479dcd8a5..2ac801bddea4 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/errors.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/errors.py @@ -5,7 +5,7 @@ import enum import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder class Color(enum.Enum): diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/function-block-arg-adjustment.py b/projects/pt1/test/python/importer/jit_ir/node_import/function-block-arg-adjustment.py index e245ec870b58..a724f118547a 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/function-block-arg-adjustment.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/function-block-arg-adjustment.py @@ -2,7 +2,7 @@ # This file is licensed under a pytorch-style license # See LICENSE.pytorch for license information. -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder from utils import create_script_function diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/function-derefine.py b/projects/pt1/test/python/importer/jit_ir/node_import/function-derefine.py index 94eed3cefdbc..89f5604bf752 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/function-derefine.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/function-derefine.py @@ -3,7 +3,7 @@ # See LICENSE.pytorch for license information. import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder import typing diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/if.py b/projects/pt1/test/python/importer/jit_ir/node_import/if.py index fd8a7267e46c..8289e05031c5 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/if.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/if.py @@ -3,7 +3,7 @@ # See LICENSE.pytorch for license information. import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/list.py b/projects/pt1/test/python/importer/jit_ir/node_import/list.py index 9a09914e3ed6..2b30d545b4c4 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/list.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/list.py @@ -3,7 +3,7 @@ # See LICENSE.pytorch for license information. import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/loop.py b/projects/pt1/test/python/importer/jit_ir/node_import/loop.py index e21f4c8c0b51..d6bb141f25d7 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/loop.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/loop.py @@ -3,7 +3,7 @@ # See LICENSE.pytorch for license information. import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder import typing diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/prim.py b/projects/pt1/test/python/importer/jit_ir/node_import/prim.py index 2565c6c41861..07a56616efa2 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/prim.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/prim.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder +from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder from utils import create_script_function diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/tuple.py b/projects/pt1/test/python/importer/jit_ir/node_import/tuple.py index 8e14b677f236..2dff435cd422 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/tuple.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/tuple.py @@ -3,7 +3,7 @@ # See LICENSE.pytorch for license information. import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder from typing import Tuple, Optional, NamedTuple from utils import create_script_function diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py b/projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py index f08fba24c405..8da5e0e2cc13 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py @@ -3,7 +3,7 @@ # See LICENSE.pytorch for license information. import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/types-none.py b/projects/pt1/test/python/importer/jit_ir/node_import/types-none.py index eae6b4578cee..a0e86a66ae2d 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/types-none.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/types-none.py @@ -3,7 +3,7 @@ # See LICENSE.pytorch for license information. import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/union.py b/projects/pt1/test/python/importer/jit_ir/node_import/union.py index 691a8e413442..14eb41a217c6 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/union.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/union.py @@ -5,7 +5,7 @@ from typing import Union import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.jit_ir_importer import ModuleBuilder # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s diff --git a/pytorch-hash.txt b/pytorch-hash.txt index c73e8b8ff023..2caf78c61ce4 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -a5a404865c01f86881f6b3ab0cd9a562d0b420de +a111e45dfe64cd565b2c0369b683f67d6658d2cc diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir new file mode 100644 index 000000000000..397d72a4896b --- /dev/null +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -0,0 +1,436 @@ +// RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch | FileCheck %s +// Generally, the test cases accumulated here come from running the importer +// over all included backend tests that involve simple ops with no model +// level constants. This is a pragmatic choice which lets us have a lot +// of tests in this file, whereas the others tend to be more bespoke. + +// CHECK-LABEL: func.func @test_abs +func.func @test_abs(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.abs %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Abs"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_add +func.func @test_add(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.add.Tensor %arg0, %arg1, %[[INT1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Add"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_add_bcast +func.func @test_add_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.add.Tensor %arg0, %arg1, %[[INT1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Add"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_add_uint8 +func.func @test_add_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.add.Tensor %arg0, %arg1, %[[INT1]] : !torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8>, !torch.int -> !torch.vtensor<[3,4,5],ui8> + %0 = torch.operator "onnx.Add"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> + return %0 : !torch.vtensor<[3,4,5],ui8> +} + +// CHECK-LABEL: @test_and_bcast3v1d +func.func @test_and_bcast3v1d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[5],i1>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.logical_and %arg0, %arg1 : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[5],i1> -> !torch.vtensor<[3,4,5],i1> + %0 = torch.operator "onnx.And"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],i1>, !torch.vtensor<[5],i1>) -> !torch.vtensor<[3,4,5],i1> + return %0 : !torch.vtensor<[3,4,5],i1> +} + +// CHECK-LABEL: @test_argmax_default_axis_example +func.func @test_argmax_default_axis_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[1,2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 0 + // CHECK: %[[BOOL:.*]] = torch.constant.bool true + // CHECK: torch.aten.argmax %arg0, %[[INT]], %[[BOOL]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,2],si64> + %0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[1,2],si64> + return %0 : !torch.vtensor<[1,2],si64> +} + +// CHECK-LABEL: @test_argmax_negative_axis_keepdims_example +func.func @test_argmax_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,1],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 1 + // CHECK: %[[BOOL:.*]] = torch.constant.bool true + // CHECK: torch.aten.argmax %arg0, %[[INT]], %[[BOOL]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,1],si64> + %0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = -1 : si64, torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,1],si64> + return %0 : !torch.vtensor<[2,1],si64> +} + +// CHECK-LABEL: @test_argmax_no_keepdims_example +func.func @test_argmax_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 1 + // CHECK: %[[BOOL:.*]] = torch.constant.bool false + // CHECK: torch.aten.argmax %arg0, %[[INT]], %[[BOOL]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2],si64> + %0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> + return %0 : !torch.vtensor<[2],si64> +} + +// CHECK-LABEL: @test_argmin_default_axis_example +func.func @test_argmin_default_axis_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[1,2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 0 + // CHECK: %[[BOOL:.*]] = torch.constant.bool true + // CHECK: torch.aten.argmin %arg0, %[[INT]], %[[BOOL]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,2],si64> + %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[1,2],si64> + return %0 : !torch.vtensor<[1,2],si64> +} + +// CHECK-LABEL: @test_argmin_negative_axis_keepdims_example +func.func @test_argmin_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,1],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 1 + // CHECK: %[[BOOL:.*]] = torch.constant.bool true + // CHECK: torch.aten.argmin %arg0, %[[INT]], %[[BOOL]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,1],si64> + %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = -1 : si64, torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,1],si64> + return %0 : !torch.vtensor<[2,1],si64> +} + +// CHECK-LABEL: @test_argmin_no_keepdims_example +func.func @test_argmin_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 1 + // CHECK: %[[BOOL:.*]] = torch.constant.bool false + // CHECK: torch.aten.argmin %arg0, %[[INT]], %[[BOOL]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2],si64> + %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> + return %0 : !torch.vtensor<[2],si64> +} + +// CHECK-LABEL: @test_atan +func.func @test_atan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.atan %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Atan"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_acos +func.func @test_acos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.acos %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Acos"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_bitshift_left_uint8 +func.func @test_bitshift_left_uint8(%arg0: !torch.vtensor<[3],ui8>, %arg1: !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8> -> !torch.vtensor<[3],ui8> + %0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "LEFT"} : (!torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8> + return %0 : !torch.vtensor<[3],ui8> +} + +// CHECK-LABEL: @test_bitshift_left_uint16 +func.func @test_bitshift_left_uint16(%arg0: !torch.vtensor<[3],ui16>, %arg1: !torch.vtensor<[3],ui16>) -> !torch.vtensor<[3],ui16> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui16>, !torch.vtensor<[3],ui16> -> !torch.vtensor<[3],ui16> + %0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "LEFT"} : (!torch.vtensor<[3],ui16>, !torch.vtensor<[3],ui16>) -> !torch.vtensor<[3],ui16> + return %0 : !torch.vtensor<[3],ui16> +} + +// CHECK-LABEL: @test_bitshift_left_uint32 +func.func @test_bitshift_left_uint32(%arg0: !torch.vtensor<[3],ui32>, %arg1: !torch.vtensor<[3],ui32>) -> !torch.vtensor<[3],ui32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui32>, !torch.vtensor<[3],ui32> -> !torch.vtensor<[3],ui32> + %0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "LEFT"} : (!torch.vtensor<[3],ui32>, !torch.vtensor<[3],ui32>) -> !torch.vtensor<[3],ui32> + return %0 : !torch.vtensor<[3],ui32> +} + +// CHECK-LABEL: @test_bitshift_left_uint64 +func.func @test_bitshift_left_uint64(%arg0: !torch.vtensor<[3],ui64>, %arg1: !torch.vtensor<[3],ui64>) -> !torch.vtensor<[3],ui64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui64>, !torch.vtensor<[3],ui64> -> !torch.vtensor<[3],ui64> + %0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "LEFT"} : (!torch.vtensor<[3],ui64>, !torch.vtensor<[3],ui64>) -> !torch.vtensor<[3],ui64> + return %0 : !torch.vtensor<[3],ui64> +} + +// CHECK-LABEL: @test_bitshift_right_uint8 +func.func @test_bitshift_right_uint8(%arg0: !torch.vtensor<[3],ui8>, %arg1: !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8> -> !torch.vtensor<[3],ui8> + %0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "RIGHT"} : (!torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8> + return %0 : !torch.vtensor<[3],ui8> +} + +// CHECK-LABEL: @test_bitshift_right_uint16 +func.func @test_bitshift_right_uint16(%arg0: !torch.vtensor<[3],ui16>, %arg1: !torch.vtensor<[3],ui16>) -> !torch.vtensor<[3],ui16> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui16>, !torch.vtensor<[3],ui16> -> !torch.vtensor<[3],ui16> + %0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "RIGHT"} : (!torch.vtensor<[3],ui16>, !torch.vtensor<[3],ui16>) -> !torch.vtensor<[3],ui16> + return %0 : !torch.vtensor<[3],ui16> +} + +// CHECK-LABEL: @test_bitshift_right_uint32 +func.func @test_bitshift_right_uint32(%arg0: !torch.vtensor<[3],ui32>, %arg1: !torch.vtensor<[3],ui32>) -> !torch.vtensor<[3],ui32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui32>, !torch.vtensor<[3],ui32> -> !torch.vtensor<[3],ui32> + %0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "RIGHT"} : (!torch.vtensor<[3],ui32>, !torch.vtensor<[3],ui32>) -> !torch.vtensor<[3],ui32> + return %0 : !torch.vtensor<[3],ui32> +} + +// CHECK-LABEL: @test_bitshift_right_uint64 +func.func @test_bitshift_right_uint64(%arg0: !torch.vtensor<[3],ui64>, %arg1: !torch.vtensor<[3],ui64>) -> !torch.vtensor<[3],ui64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui64>, !torch.vtensor<[3],ui64> -> !torch.vtensor<[3],ui64> + %0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "RIGHT"} : (!torch.vtensor<[3],ui64>, !torch.vtensor<[3],ui64>) -> !torch.vtensor<[3],ui64> + return %0 : !torch.vtensor<[3],ui64> +} + +// CHECK-LABEL: @test_bitwise_and_i16_3d +func.func @test_bitwise_and_i16_3d(%arg0: !torch.vtensor<[3,4,5],si16>, %arg1: !torch.vtensor<[3,4,5],si16>) -> !torch.vtensor<[3,4,5],si16> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si16>, !torch.vtensor<[3,4,5],si16> -> !torch.vtensor<[3,4,5],si16> + %0 = torch.operator "onnx.BitwiseAnd"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],si16>, !torch.vtensor<[3,4,5],si16>) -> !torch.vtensor<[3,4,5],si16> + return %0 : !torch.vtensor<[3,4,5],si16> +} + +// CHECK-LABEL: @test_bitwise_and_i32_2d +func.func @test_bitwise_and_i32_2d(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si32>, !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],si32> + %0 = torch.operator "onnx.BitwiseAnd"(%arg0, %arg1) : (!torch.vtensor<[3,4],si32>, !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> + return %0 : !torch.vtensor<[3,4],si32> +} + +// CHECK-LABEL: @test_bitwise_and_ui8_bcast_4v3d +func.func @test_bitwise_and_ui8_bcast_4v3d(%arg0: !torch.vtensor<[3,4,5,6],ui8>, %arg1: !torch.vtensor<[4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5,6],ui8>, !torch.vtensor<[4,5,6],ui8> -> !torch.vtensor<[3,4,5,6],ui8> + %0 = torch.operator "onnx.BitwiseAnd"(%arg0, %arg1) : (!torch.vtensor<[3,4,5,6],ui8>, !torch.vtensor<[4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8> + return %0 : !torch.vtensor<[3,4,5,6],ui8> +} + +// CHECK-LABEL: @test_bitwise_or_i16_4d +func.func @test_bitwise_or_i16_4d(%arg0: !torch.vtensor<[3,4,5,6],si8>, %arg1: !torch.vtensor<[3,4,5,6],si8>) -> !torch.vtensor<[3,4,5,6],si8> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_or.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5,6],si8>, !torch.vtensor<[3,4,5,6],si8> -> !torch.vtensor<[3,4,5,6],si8> + %0 = torch.operator "onnx.BitwiseOr"(%arg0, %arg1) : (!torch.vtensor<[3,4,5,6],si8>, !torch.vtensor<[3,4,5,6],si8>) -> !torch.vtensor<[3,4,5,6],si8> + return %0 : !torch.vtensor<[3,4,5,6],si8> +} + +// CHECK-LABEL: @test_bitwise_or_i32_2d +func.func @test_bitwise_or_i32_2d(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_or.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si32>, !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],si32> + %0 = torch.operator "onnx.BitwiseOr"(%arg0, %arg1) : (!torch.vtensor<[3,4],si32>, !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> + return %0 : !torch.vtensor<[3,4],si32> +} + +// CHECK-LABEL: @test_bitwise_or_ui8_bcast_4v3d +func.func @test_bitwise_or_ui8_bcast_4v3d(%arg0: !torch.vtensor<[3,4,5,6],ui8>, %arg1: !torch.vtensor<[4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_or.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5,6],ui8>, !torch.vtensor<[4,5,6],ui8> -> !torch.vtensor<[3,4,5,6],ui8> + %0 = torch.operator "onnx.BitwiseOr"(%arg0, %arg1) : (!torch.vtensor<[3,4,5,6],ui8>, !torch.vtensor<[4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8> + return %0 : !torch.vtensor<[3,4,5,6],ui8> +} + +// CHECK-LABEL: @test_bitwise_not_2d +func.func @test_bitwise_not_2d(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],si32> + %0 = torch.operator "onnx.BitwiseNot"(%arg0) : (!torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> + return %0 : !torch.vtensor<[3,4],si32> +} + +// CHECK-LABEL: @test_bitwise_not_4d +func.func @test_bitwise_not_4d(%arg0: !torch.vtensor<[3,4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4,5,6],ui8> -> !torch.vtensor<[3,4,5,6],ui8> + %0 = torch.operator "onnx.BitwiseNot"(%arg0) : (!torch.vtensor<[3,4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8> + return %0 : !torch.vtensor<[3,4,5,6],ui8> +} + +// CHECK-LABEL: @test_bitwise_xor_i16_3d +func.func @test_bitwise_xor_i16_3d(%arg0: !torch.vtensor<[3,4,5],si16>, %arg1: !torch.vtensor<[3,4,5],si16>) -> !torch.vtensor<[3,4,5],si16> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_xor.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si16>, !torch.vtensor<[3,4,5],si16> -> !torch.vtensor<[3,4,5],si16> + %0 = torch.operator "onnx.BitwiseXor"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],si16>, !torch.vtensor<[3,4,5],si16>) -> !torch.vtensor<[3,4,5],si16> + return %0 : !torch.vtensor<[3,4,5],si16> +} + +// CHECK-LABEL: @test_bitwise_xor_i32_2d +func.func @test_bitwise_xor_i32_2d(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_xor.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si32>, !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],si32> + %0 = torch.operator "onnx.BitwiseXor"(%arg0, %arg1) : (!torch.vtensor<[3,4],si32>, !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> + return %0 : !torch.vtensor<[3,4],si32> +} + +// CHECK-LABEL: @test_bitwise_xor_ui8_bcast_4v3d +func.func @test_bitwise_xor_ui8_bcast_4v3d(%arg0: !torch.vtensor<[3,4,5,6],ui8>, %arg1: !torch.vtensor<[4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_xor.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5,6],ui8>, !torch.vtensor<[4,5,6],ui8> -> !torch.vtensor<[3,4,5,6],ui8> + %0 = torch.operator "onnx.BitwiseXor"(%arg0, %arg1) : (!torch.vtensor<[3,4,5,6],ui8>, !torch.vtensor<[4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8> + return %0 : !torch.vtensor<[3,4,5,6],ui8> +} + +// CHECK-LABEL: @test_cast_BFLOAT16_to_FLOAT +func.func @test_cast_BFLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],bf16>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 6 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Cast"(%arg0) {torch.onnx.to = 1 : si64} : (!torch.vtensor<[3,4],bf16>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// CHECK-LABEL: @test_cast_DOUBLE_to_FLOAT +func.func @test_cast_DOUBLE_to_FLOAT(%arg0: !torch.vtensor<[3,4],f64>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 6 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Cast"(%arg0) {torch.onnx.to = 1 : si64} : (!torch.vtensor<[3,4],f64>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// CHECK-LABEL: @test_cast_DOUBLE_to_FLOAT16 +func.func @test_cast_DOUBLE_to_FLOAT16(%arg0: !torch.vtensor<[3,4],f64>) -> !torch.vtensor<[3,4],f16> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 5 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f16> + %0 = torch.operator "onnx.Cast"(%arg0) {torch.onnx.to = 10 : si64} : (!torch.vtensor<[3,4],f64>) -> !torch.vtensor<[3,4],f16> + return %0 : !torch.vtensor<[3,4],f16> +} + +// CHECK-LABEL: @test_cast_FLOAT_to_BFLOAT16 +func.func @test_cast_FLOAT_to_BFLOAT16(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],bf16> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 15 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],bf16> + %0 = torch.operator "onnx.Cast"(%arg0) {torch.onnx.to = 16 : si64} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],bf16> + return %0 : !torch.vtensor<[3,4],bf16> +} + +// CHECK-LABEL: @test_cast_FLOAT_to_DOUBLE +func.func @test_cast_FLOAT_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 7 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f64> + %0 = torch.operator "onnx.Cast"(%arg0) {torch.onnx.to = 11 : si64} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f64> + return %0 : !torch.vtensor<[3,4],f64> +} + +// CHECK-LABEL: @test_cast_FLOAT_to_FLOAT16 +func.func @test_cast_FLOAT_to_FLOAT16(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f16> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 5 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f16> + %0 = torch.operator "onnx.Cast"(%arg0) {torch.onnx.to = 10 : si64} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f16> + return %0 : !torch.vtensor<[3,4],f16> +} + +// CHECK-LABEL: @test_cast_FLOAT16_to_DOUBLE +func.func @test_cast_FLOAT16_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f16>) -> !torch.vtensor<[3,4],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 7 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f64> + %0 = torch.operator "onnx.Cast"(%arg0) {torch.onnx.to = 11 : si64} : (!torch.vtensor<[3,4],f16>) -> !torch.vtensor<[3,4],f64> + return %0 : !torch.vtensor<[3,4],f64> +} + +// CHECK-LABEL: @test_cast_FLOAT16_to_FLOAT +func.func @test_cast_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 6 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Cast"(%arg0) {torch.onnx.to = 1 : si64} : (!torch.vtensor<[3,4],f16>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// CHECK-LABEL: @test_ceil_example +func.func @test_ceil_example(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.ceil %arg0 : !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32> + %0 = torch.operator "onnx.Ceil"(%arg0) : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> + return %0 : !torch.vtensor<[2],f32> +} + +// CHECK-LABEL: @test_ceil +func.func @test_ceil(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.ceil %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Ceil"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_clip_default_int8_min +func.func @test_clip_default_int8_min(%arg0: !torch.vtensor<[3,4,5],si8>, %arg1: !torch.vtensor<[],si8>) -> !torch.vtensor<[3,4,5],si8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.clamp_min.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si8>, !torch.vtensor<[],si8> -> !torch.vtensor<[3,4,5],si8> + %0 = torch.operator "onnx.Clip"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],si8>, !torch.vtensor<[],si8>) -> !torch.vtensor<[3,4,5],si8> + return %0 : !torch.vtensor<[3,4,5],si8> +} + +// CHECK-LABEL: @test_clip_default_min +func.func @test_clip_default_min(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.clamp_min.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Clip"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_clip_example +func.func @test_clip_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.clamp.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[3],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Clip"(%arg0, %arg1, %arg2) : (!torch.vtensor<[3],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// CHECK-LABEL: @test_clip +func.func @test_clip(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.clamp.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Clip"(%arg0, %arg1, %arg2) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_cos_example +func.func @test_cos_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.cos %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Cos"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// CHECK-LABEL: @test_cos +func.func @test_cos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.cos %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Cos"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_div_bcast +func.func @test_div_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Div"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_div_example +func.func @test_div_example(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32> + %0 = torch.operator "onnx.Div"(%arg0, %arg1) : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> + return %0 : !torch.vtensor<[2],f32> +} + +// CHECK-LABEL: @test_div +func.func @test_div(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Div"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_div_uint8 +func.func @test_div_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8> -> !torch.vtensor<[3,4,5],ui8> + %0 = torch.operator "onnx.Div"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> + return %0 : !torch.vtensor<[3,4,5],ui8> +} + +// CHECK-LABEL: @test_equal_bcast +func.func @test_equal_bcast(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[5],si32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[5],si32> -> !torch.vtensor<[3,4,5],i1> + %0 = torch.operator "onnx.Equal"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],si32>, !torch.vtensor<[5],si32>) -> !torch.vtensor<[3,4,5],i1> + return %0 : !torch.vtensor<[3,4,5],i1> +} + +// CHECK-LABEL: @test_equal +func.func @test_equal(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32> -> !torch.vtensor<[3,4,5],i1> + %0 = torch.operator "onnx.Equal"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],i1> + return %0 : !torch.vtensor<[3,4,5],i1> +} + +// CHECK-LABEL: @test_floor_example +func.func @test_floor_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.floor %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Floor"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// CHECK-LABEL: @test_floor +func.func @test_floor(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.floor %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Floor"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} diff --git a/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir b/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir new file mode 100644 index 000000000000..22d5e2d35183 --- /dev/null +++ b/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir @@ -0,0 +1,18 @@ +// RUN: torch-mlir-opt <%s -split-input-file -verify-diagnostics -convert-torch-onnx-to-torch + +module { + func.func @test_argmax_no_keepdims_random_select_last_index(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // TODO: Unsupported torch.onnx.select_last_index + // expected-error @+1 {{failed to legalize operation 'torch.operator'}} + %0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64> + return %0 : !torch.vtensor<[2,4],si64> + } +} + +// ----- +func.func @test_argmin_no_keepdims_example_select_last_index(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // TODO: Unsupported torch.onnx.select_last_index + // expected-error @+1 {{failed to legalize operation 'torch.operator'}} + %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> + return %0 : !torch.vtensor<[2],si64> +} diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir index 470962e2494d..eba7546655e9 100644 --- a/test/Conversion/TorchToLinalg/basic.mlir +++ b/test/Conversion/TorchToLinalg/basic.mlir @@ -40,6 +40,17 @@ func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // ----- +// CHECK-LABEL: func.func @torch.aten.mm$basic_unsigned( +// CHECK: linalg.matmul_unsigned +func.func @torch.aten.mm$basic_unsigned(%arg0: !torch.vtensor<[?,?],ui32>, %arg1: !torch.vtensor<[?,?],ui32>) -> !torch.vtensor<[?,2],ui32> + attributes {torch.assume_strict_symbolic_shapes} +{ + %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],ui32>, !torch.vtensor<[?,?],ui32> -> !torch.vtensor<[?,2],ui32> + return %0 : !torch.vtensor<[?,2],ui32> +} + +// ----- + // If the operands are missing dtype, we cannot lower it. func.func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor { // expected-error@+1 {{failed to legalize}} diff --git a/test/Conversion/TorchToLinalg/view.mlir b/test/Conversion/TorchToLinalg/view.mlir index e02d946aa5b7..83424a17d843 100644 --- a/test/Conversion/TorchToLinalg/view.mlir +++ b/test/Conversion/TorchToLinalg/view.mlir @@ -5,11 +5,9 @@ // CHECK-LABEL: func.func @torch.aten.view$twotothree( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> { // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[3,2],f32> -> tensor<3x2xf32> -// CHECK: %[[CASTED:.*]] = tensor.cast %[[BUILTIN_TENSOR]] : tensor<3x2xf32> to tensor<3x2xf32> -// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CASTED]] {{\[\[}}0, 1]] : tensor<3x2xf32> into tensor<6xf32> +// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1]] : tensor<3x2xf32> into tensor<6xf32> // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1]] : tensor<6xf32> into tensor<2x3xf32> -// CHECK: %[[EXPAND_CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<2x3xf32> to tensor<2x3xf32> -// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND_CAST]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32> +// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,3],f32> func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> { @@ -18,13 +16,14 @@ func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torc %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[3,2],f32>, !torch.list -> !torch.vtensor<[2,3],f32> return %1 : !torch.vtensor<[2,3],f32> - } +} + +// ----- // CHECK-LABEL: func.func @torch.aten.view$dynamictest( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[CASTED:.*]] = tensor.cast %[[BUILTIN_TENSOR]] : tensor to tensor -// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[BUILTIN_TENSOR]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -35,7 +34,29 @@ func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !tor %2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?],f32>, !torch.list -> !torch.vtensor<[?,?],f32> return %3 : !torch.vtensor<[?,?],f32> - } +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.view$dynamictest2( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,6,?],f32>) -> !torch.vtensor<[?,2,3,?],f32> { +// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,6,?],f32> -> tensor +// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2], [3]] : tensor into tensor +// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor -> !torch.vtensor<[?,2,3,?],f32> +// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,2,3,?],f32> + +func.func @torch.aten.view$dynamictest2(%arg0: !torch.vtensor<[?,6,?],f32>) -> !torch.vtensor<[?,2,3,?],f32> { + %int3 = torch.constant.int 3 + %int2 = torch.constant.int 2 + %int0 = torch.constant.int 0 + %2 = torch.aten.size.int %arg0, %int2 : !torch.vtensor<[?,6,?],f32>, !torch.int -> !torch.int + %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,6,?],f32>, !torch.int -> !torch.int + %1 = torch.prim.ListConstruct %0, %int2, %int3, %2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3 = torch.aten.view %arg0, %1 : !torch.vtensor<[?,6,?],f32>, !torch.list -> !torch.vtensor<[?,2,3,?], f32> + return %3 : !torch.vtensor<[?,2,3,?], f32> +} + +// ----- // CHECK-LABEL: func.func @torch.aten.view$dynamicVal( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,?,128],f32>) -> !torch.vtensor<[16,1,128],f32> { @@ -43,8 +64,7 @@ func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !tor // CHECK: %[[CASTED:.*]] = tensor.cast %[[BUILTIN_TENSOR]] : tensor<1x?x128xf32> to tensor<1x16x128xf32> // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CASTED]] {{\[\[}}0, 1], [2]] : tensor<1x16x128xf32> into tensor<16x128xf32> // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0], [1, 2]] : tensor<16x128xf32> into tensor<16x1x128xf32> -// CHECK: %[[EXPAND_CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<16x1x128xf32> to tensor<16x1x128xf32> -// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND_CAST]] : tensor<16x1x128xf32> -> !torch.vtensor<[16,1,128],f32> +// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<16x1x128xf32> -> !torch.vtensor<[16,1,128],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[16,1,128],f32> func.func @torch.aten.view$dynamicVal(%arg0: !torch.vtensor<[1,?,128],f32>) -> !torch.vtensor<[16,1,128],f32> { @@ -54,16 +74,58 @@ func.func @torch.aten.view$dynamicVal(%arg0: !torch.vtensor<[1,?,128],f32>) -> ! %0 = torch.prim.ListConstruct %int16, %int1, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[16,1,128],f32> return %1 : !torch.vtensor<[16,1,128],f32> - } +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten$dynamicValOutput( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[8,1,?,1],f32> { +// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32> +// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2]] : tensor<4x5x6xf32> into tensor<120xf32> +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2, 3]] : tensor<120xf32> into tensor<8x1x15x1xf32> +// CHECK: %[[CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<8x1x15x1xf32> to tensor<8x1x?x1xf32> +// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<8x1x?x1xf32> -> !torch.vtensor<[8,1,?,1],f32> +// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[8,1,?,1],f32> + +func.func @torch.aten$dynamicValOutput(%arg0: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[8,1,?,1],f32> { + %int8 = torch.constant.int 8 + %int1 = torch.constant.int 1 + %int-1 = torch.constant.int -1 + %0 = torch.prim.ListConstruct %int8, %int1, %int-1, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[4,5,6],f32>, !torch.list -> !torch.vtensor<[8,1,?,1],f32> + return %1 : !torch.vtensor<[8,1,?,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten$dynamicValOutput2( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[2,1,2,3,?],f32> { +// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32> +// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2]] : tensor<4x5x6xf32> into tensor<4x30xf32> +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2], [3, 4]] : tensor<4x30xf32> into tensor<2x1x2x3x10xf32> +// CHECK: %[[CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<2x1x2x3x10xf32> to tensor<2x1x2x3x?xf32> +// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<2x1x2x3x?xf32> -> !torch.vtensor<[2,1,2,3,?],f32> +// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,1,2,3,?],f32> + +// 4 -> [2,1,2] [5,6] -> [3,10]. +func.func @torch.aten$dynamicValOutput2(%arg0: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[2,1,2,3,?],f32> { + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int3 = torch.constant.int 3 + %int-1 = torch.constant.int -1 + %0 = torch.prim.ListConstruct %int2, %int1, %int2, %int3, %int-1 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[4,5,6],f32>, !torch.list -> !torch.vtensor<[2,1,2,3,?],f32> + return %1 : !torch.vtensor<[2,1,2,3,?],f32> +} + +// ----- // CHECK-LABEL: func.func @torch.aten.view$expandInferredDim( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32> { // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,6],f32> -> tensor<2x6xf32> -// CHECK: %[[CASTED:.*]] = tensor.cast %[[BUILTIN_TENSOR]] : tensor<2x6xf32> to tensor<2x6xf32> -// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CASTED]] {{\[\[}}0, 1]] : tensor<2x6xf32> into tensor<12xf32> +// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1]] : tensor<2x6xf32> into tensor<12xf32> // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2]] : tensor<12xf32> into tensor<3x2x2xf32> -// CHECK: %[[EXPAND_CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<3x2x2xf32> to tensor<3x2x2xf32> -// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND_CAST]] : tensor<3x2x2xf32> -> !torch.vtensor<[3,2,2],f32> +// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<3x2x2xf32> -> !torch.vtensor<[3,2,2],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[3,2,2],f32> func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32> { @@ -73,4 +135,123 @@ func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) - %0 = torch.prim.ListConstruct %int3, %int2, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[2,6],f32>, !torch.list -> !torch.vtensor<[3,2,2],f32> return %1 : !torch.vtensor<[3,2,2],f32> - } +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.view$singleUnknownMatches0( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[10,3,?,2,3],f32>) -> !torch.vtensor<[2,3,5,?,6],f32> { +// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[10,3,?,2,3],f32> -> tensor<10x3x?x2x3xf32> +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3, 4]] : tensor<10x3x?x2x3xf32> into tensor<30x?x6xf32> +// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[COLLAPSE]] {{\[\[}}0, 1, 2], [3], [4]] : tensor<30x?x6xf32> into tensor<2x3x5x?x6xf32> +// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor<2x3x5x?x6xf32> -> !torch.vtensor<[2,3,5,?,6],f32> +// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,3,5,?,6],f32> + +// [10,3,?,2,3] -> [30,?,6] -> [2,3,5,?,6] +// Associations are, +// -- for collapse, [0,1], [2], [3,4] and +// -- for expand [0,1,2], [3], [4]. +func.func @torch.aten.view$singleUnknownMatches0(%arg0: !torch.vtensor<[10,3,?,2,3],f32>) -> !torch.vtensor<[2,3,5,?,6],f32> { + %int3 = torch.constant.int 3 + %int2 = torch.constant.int 2 + %int6 = torch.constant.int 6 + %int5 = torch.constant.int 5 + %int-1 = torch.constant.int -1 + %0 = torch.prim.ListConstruct %int2, %int3, %int5, %int-1, %int6 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[10,3,?,2,3],f32>, !torch.list -> !torch.vtensor<[2,3,5,?,6],f32> + return %1 : !torch.vtensor<[2,3,5,?,6],f32> +} + +// ----- + +// Multiple aspects of decomposition here: +// 1) an expand from (8) to (2,2,2) +// 2) a collapse from (2,1,3) to (6) +// 3) a single unknown dim matching in the middle. +// 4) on either side of the unkown dim (3), another unkown dim, +// but one which matches between the input and the output + +// CHECK: func.func @torch.aten.view$combineConcepts( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[8,?,?,?,2,1,3],f32>) -> !torch.vtensor<[2,2,2,?,?,?,6],f32> { +// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[8,?,?,?,2,1,3],f32> -> tensor<8x?x?x?x2x1x3xf32> +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2], [3], [4, 5, 6]] : tensor<8x?x?x?x2x1x3xf32> into tensor<8x?x?x?x6xf32> +// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[COLLAPSE]] {{\[\[}}0, 1, 2], [3], [4], [5], [6]] : tensor<8x?x?x?x6xf32> into tensor<2x2x2x?x?x?x6xf32> +// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor<2x2x2x?x?x?x6xf32> -> !torch.vtensor<[2,2,2,?,?,?,6],f32> +// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,2,2,?,?,?,6],f32> + +func.func @torch.aten.view$combineConcepts(%arg0 : !torch.vtensor<[8,?,?,?,2,1,3], f32>) -> !torch.vtensor<[2,2,2,?,?,?,6], f32> { + + %int1 = torch.constant.int 1 + %size1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[8,?,?,?,2,1,3], f32>, !torch.int -> !torch.int + + %int3 = torch.constant.int 3 + %size3 = torch.aten.size.int %arg0, %int3 : !torch.vtensor<[8,?,?,?,2,1,3], f32>, !torch.int -> !torch.int + + %int2 = torch.constant.int 2 + %int6 = torch.constant.int 6 + %int-1 = torch.constant.int -1 + %0 = torch.prim.ListConstruct %int2, %int2, %int2, %size1, %int-1, %size3, %int6 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[8,?,?,?,2,1,3], f32>, !torch.list -> !torch.vtensor<[2,2,2,?,?,?,6], f32> + return %1 : !torch.vtensor<[2,2,2,?,?,?,6], f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.view$multiDynamicsInSourceOfCollapse +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,2,?,4,?],f32>) -> !torch.vtensor<[?],f32> { +// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,2,?,4,?],f32> -> tensor +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2, 3, 4]] : tensor into tensor +// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[COLLAPSE]] : tensor -> !torch.vtensor<[?],f32> +// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?],f32> +func.func @torch.aten.view$multiDynamicsInSourceOfCollapse (%arg0 : !torch.vtensor<[?,2,?,4,?], f32>) -> !torch.vtensor<[?], f32> { + %int-1 = torch.constant.int -1 + %0 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list + %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[?,2,?,4,?], f32>, !torch.list -> !torch.vtensor<[?], f32> + return %1 : !torch.vtensor<[?], f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.view$castingView +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[3,4,5],f32> { + +// The current lowring only succeeds if the input (arg0) has shape [3,4,5], +// determined at runtime. This is a bit limiting, and we'll probably want to +// improve that in the future. For now we check that there are 2 runtime +// asserts on the sizes of dimensions 0 and 1 (size of dimension 2 implied). + +// CHECK-COUNT-2: cf.assert {{.*}} "mismatching contracting dimension +// CHECK: return {{.*}} : !torch.vtensor<[3,4,5],f32> + +func.func @torch.aten.view$castingView (%arg0 : !torch.vtensor<[?,?,?], f32>) -> !torch.vtensor<[3,4,5], f32> { + %int3 = torch.constant.int 3 + %int4 = torch.constant.int 4 + %int5 = torch.constant.int 5 + %0 = torch.prim.ListConstruct %int3, %int4, %int5 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[?,?,?], f32>, !torch.list -> !torch.vtensor<[3,4,5], f32> + return %1 : !torch.vtensor<[3,4,5], f32> +} + +// ----- + +// A function with a torch.view op, going from shape (10,?,2,3) to (2,5,?,6). +// We expect this to lower to a collapse with [0], [1], [2,3] followed by +// an expand with [0,1], [2], [3]: +// CHECK: func.func @torch.aten.view$dynamicInferredSame( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[10,?,2,3],f32>) -> !torch.vtensor<[2,5,?,6],f32> { +// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[10,?,2,3],f32> -> tensor<10x?x2x3xf32> +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<10x?x2x3xf32> into tensor<10x?x6xf32> +// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[COLLAPSE]] {{\[\[}}0, 1], [2], [3]] : tensor<10x?x6xf32> into tensor<2x5x?x6xf32> +// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor<2x5x?x6xf32> -> !torch.vtensor<[2,5,?,6],f32> +// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,5,?,6],f32> + +func.func @torch.aten.view$dynamicInferredSame(%arg0: !torch.vtensor<[10,?,2,3],f32>) -> !torch.vtensor<[2,5,?,6],f32> { + %int2 = torch.constant.int 2 + %int5 = torch.constant.int 5 + %int6 = torch.constant.int 6 + %int-1 = torch.constant.int -1 + %0 = torch.prim.ListConstruct %int2, %int5, %int-1, %int6 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[10,?,2,3],f32>, !torch.list -> !torch.vtensor<[2,5,?,6],f32> + return %1 : !torch.vtensor<[2,5,?,6],f32> +} + diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 82535062a720..5dfd8daa9d44 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2136,3 +2136,13 @@ func.func @torch.aten.numel$canonicalize(%arg0: !torch.vtensor<[3,4],f32>) -> !t %0 = torch.aten.numel %arg0 : !torch.vtensor<[3,4],f32> -> !torch.int return %0 : !torch.int } + +// CHECK-LABEL: func.func @torch.aten.masked_fill.Tensor$canonicalize +// CHECK-NEXT: torch.constant.float -1.000000e+09 +// CHECK-NEXT: torch.aten.masked_fill.Scalar +// CHECK-NEXT: return +func.func @torch.aten.masked_fill.Tensor$canonicalize(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.vtensor.literal(dense<-1.000000e+09> : tensor) : !torch.vtensor<[],f32> + %1 = torch.aten.masked_fill.Tensor %arg0, %arg1, %0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32> + return %1 : !torch.vtensor<[?,?],f32> +} diff --git a/test/Dialect/Torch/reify-dtype-calculations.mlir b/test/Dialect/Torch/reify-dtype-calculations.mlir index 9aec26662b69..3fe94d0417e1 100644 --- a/test/Dialect/Torch/reify-dtype-calculations.mlir +++ b/test/Dialect/Torch/reify-dtype-calculations.mlir @@ -24,11 +24,11 @@ func.func @basic(%arg0: !torch.vtensor) -> !torch.vtensor { // ----- -// CHECK-LABEL: func.func private @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes( +// CHECK-LABEL: func.func private @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes( // CHECK: {{.*}} = torch.promote_dtypes {{.*}} : (!torch.list>, !torch.list) -> !torch.int // CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten.floor_divide( -// CHECK: {{.*}} = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes({{.*}} +// CHECK: {{.*}} = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes({{.*}} // CHECK-LABEL: func.func @op_with_dtype_promotion( // CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.floor_divide({{.*}} diff --git a/utils/bazel/WORKSPACE.bazel b/utils/bazel/WORKSPACE.bazel index f7a81a4faf29..351ba301cf93 100644 --- a/utils/bazel/WORKSPACE.bazel +++ b/utils/bazel/WORKSPACE.bazel @@ -33,6 +33,10 @@ llvm_configure( }, targets = [ "X86", + # The bazel dependency graph for mlir-opt fails to load (at the analysis step) without the NVPTX + # target in this list, because mlir/test:TestGPU depends on the //llvm:NVPTXCodeGen target, + # which is not defined unless this is included. + "NVPTX", ], ) diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index bdf9c9fc72f9..2a9edaac503c 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -282,6 +282,31 @@ gentbl_cc_library( ], ) +td_library( + name = "TorchMLIRConversionTorchOnnxToTorchPassTdFiles", + srcs = [ + "include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.td", + ], + includes = ["include"], +) + +gentbl_cc_library( + name = "TorchMLIRConversionTorchOnnxToTorchPassIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-pass-decls"], + "include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.td", + deps = [ + ":TorchMLIRConversionTorchOnnxToTorchPassTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) + # TorchConversion transforms td_library( name = "TorchMLIRTorchConversionPassesTdFiles", @@ -454,6 +479,22 @@ cc_library( ], ) +cc_library( + name = "TorchMLIRTorchOnnxToTorch", + srcs = glob([ + "lib/Conversion/TorchOnnxToTorch/*.h", + "lib/Conversion/TorchOnnxToTorch/*.cpp", + ]), + hdrs = glob(["include/torch-mlir/Conversion/TorchOnnxToTorch/*.h"]), + strip_include_prefix = "include", + deps = [ + ":TorchMLIRConversionTorchOnnxToTorchPassIncGen", + ":TorchMLIRTorchDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + ], +) + cc_library( name = "TorchMLIRConversionPasses", srcs = [ @@ -468,6 +509,7 @@ cc_library( strip_include_prefix = "include", deps = [ ":TorchMLIRTorchConversionToMLProgram", + ":TorchMLIRTorchOnnxToTorch", ":TorchMLIRTorchToArith", ":TorchMLIRTorchToLinalg", ":TorchMLIRTorchToSCF",